mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat-agent-mask
This commit is contained in:
commit
79c4baa7c2
|
|
@ -1,5 +1,5 @@
|
|||
---
|
||||
name: Dify Frontend Testing
|
||||
name: frontend-testing
|
||||
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
|
||||
---
|
||||
|
||||
|
|
@ -178,7 +178,7 @@ Process in this order for multi-file testing:
|
|||
- **500+ lines**: Consider splitting before testing
|
||||
- **Many dependencies**: Extract logic into hooks first
|
||||
|
||||
> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
|
||||
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
|
|
@ -289,17 +289,18 @@ For each test file generated, aim for:
|
|||
- ✅ **>95%** branch coverage
|
||||
- ✅ **>95%** line coverage
|
||||
|
||||
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
|
||||
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
|
||||
|
||||
## Detailed Guides
|
||||
|
||||
For more detailed information, refer to:
|
||||
|
||||
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
|
||||
- `guides/mocking.md` - Mock patterns and best practices
|
||||
- `guides/async-testing.md` - Async operations and API calls
|
||||
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
|
||||
- `guides/common-patterns.md` - Frequently used testing patterns
|
||||
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
|
||||
- `references/mocking.md` - Mock patterns and best practices
|
||||
- `references/async-testing.md` - Async operations and API calls
|
||||
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
|
||||
- `references/common-patterns.md` - Frequently used testing patterns
|
||||
- `references/checklist.md` - Test generation checklist and validation steps
|
||||
|
||||
## Authoritative References
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
../.claude/skills
|
||||
|
|
@ -6,7 +6,7 @@ cd web && pnpm install
|
|||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
|
|
|
|||
|
|
@ -6,6 +6,12 @@
|
|||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# CODEOWNERS file
|
||||
.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
# Docs
|
||||
docs/ @crazywoola
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
api/ @QuantumGhost
|
||||
|
||||
|
|
@ -116,11 +122,17 @@ api/controllers/console/feature.py @GarfieldDai @GareArc
|
|||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
|
||||
# Backend - Database Migrations
|
||||
api/migrations/ @snakevash @laipz8200
|
||||
api/migrations/ @snakevash @laipz8200 @MRZHUH
|
||||
|
||||
# Backend - Vector DB Middleware
|
||||
api/configs/middleware/vdb/* @JohnJyong
|
||||
|
||||
# Frontend
|
||||
web/ @iamjoel
|
||||
|
||||
# Frontend - Web Tests
|
||||
.github/workflows/web-tests.yml @iamjoel
|
||||
|
||||
# Frontend - App - Orchestration
|
||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
|
|
@ -192,6 +204,7 @@ web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
|||
web/app/signin/ @douxc @iamjoel
|
||||
web/app/signup/ @douxc @iamjoel
|
||||
web/app/reset-password/ @douxc @iamjoel
|
||||
|
||||
web/app/install/ @douxc @iamjoel
|
||||
web/app/init/ @douxc @iamjoel
|
||||
web/app/forgot-password/ @douxc @iamjoel
|
||||
|
|
@ -232,3 +245,6 @@ web/app/education-apply/ @iamjoel @zxhlyh
|
|||
|
||||
# Frontend - Workspace
|
||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
|
||||
# Docker
|
||||
docker/* @laipz8200
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ jobs:
|
|||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
|
||||
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
|
|
@ -79,7 +79,7 @@ jobs:
|
|||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
working-directory: ./web
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ jobs:
|
|||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ jobs:
|
|||
with:
|
||||
node-version: 'lts/*'
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
|
|
@ -21,14 +22,7 @@ jobs:
|
|||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: web/**
|
||||
|
||||
- name: Install pnpm
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
|
|
@ -36,23 +30,355 @@ jobs:
|
|||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Restore Jest cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: web/.cache/jest
|
||||
key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-jest-
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Check i18n types synchronization
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run check:i18n-types
|
||||
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm test
|
||||
run: |
|
||||
pnpm exec jest \
|
||||
--ci \
|
||||
--maxWorkers=100% \
|
||||
--coverage \
|
||||
--passWithNoTests
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
id: coverage-summary
|
||||
run: |
|
||||
set -eo pipefail
|
||||
|
||||
COVERAGE_FILE="coverage/coverage-final.json"
|
||||
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
|
||||
|
||||
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
let libCoverage = null;
|
||||
|
||||
try {
|
||||
libCoverage = require('istanbul-lib-coverage');
|
||||
} catch (error) {
|
||||
libCoverage = null;
|
||||
}
|
||||
|
||||
const summaryPath = path.join('coverage', 'coverage-summary.json');
|
||||
const finalPath = path.join('coverage', 'coverage-final.json');
|
||||
|
||||
const hasSummary = fs.existsSync(summaryPath);
|
||||
const hasFinal = fs.existsSync(finalPath);
|
||||
|
||||
if (!hasSummary && !hasFinal) {
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('No coverage data found.');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const summary = hasSummary
|
||||
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
|
||||
: null;
|
||||
const coverage = hasFinal
|
||||
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
|
||||
: null;
|
||||
|
||||
const getLineCoverageFromStatements = (statementMap, statementHits) => {
|
||||
const lineHits = {};
|
||||
|
||||
if (!statementMap || !statementHits) {
|
||||
return lineHits;
|
||||
}
|
||||
|
||||
Object.entries(statementMap).forEach(([key, statement]) => {
|
||||
const line = statement?.start?.line;
|
||||
if (!line) {
|
||||
return;
|
||||
}
|
||||
const hits = statementHits[key] ?? 0;
|
||||
const previous = lineHits[line];
|
||||
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
|
||||
});
|
||||
|
||||
return lineHits;
|
||||
};
|
||||
|
||||
const getFileCoverage = (entry) => (
|
||||
libCoverage ? libCoverage.createFileCoverage(entry) : null
|
||||
);
|
||||
|
||||
const getLineHits = (entry, fileCoverage) => {
|
||||
const lineHits = entry.l ?? {};
|
||||
if (Object.keys(lineHits).length > 0) {
|
||||
return lineHits;
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getLineCoverage();
|
||||
}
|
||||
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
|
||||
};
|
||||
|
||||
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
|
||||
if (lineHits && Object.keys(lineHits).length > 0) {
|
||||
return Object.entries(lineHits)
|
||||
.filter(([, count]) => count === 0)
|
||||
.map(([line]) => Number(line))
|
||||
.sort((a, b) => a - b);
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getUncoveredLines();
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const totals = {
|
||||
lines: { covered: 0, total: 0 },
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
};
|
||||
const fileSummaries = [];
|
||||
|
||||
if (summary) {
|
||||
const totalEntry = summary.total ?? {};
|
||||
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
|
||||
if (totalEntry[key]) {
|
||||
totals[key].covered = totalEntry[key].covered ?? 0;
|
||||
totals[key].total = totalEntry[key].total ?? 0;
|
||||
}
|
||||
});
|
||||
|
||||
Object.entries(summary)
|
||||
.filter(([file]) => file !== 'total')
|
||||
.forEach(([file, data]) => {
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
|
||||
lines: {
|
||||
covered: data.lines?.covered ?? 0,
|
||||
total: data.lines?.total ?? 0,
|
||||
},
|
||||
});
|
||||
});
|
||||
} else if (coverage) {
|
||||
Object.entries(coverage).forEach(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
totals.lines.total += lineTotal;
|
||||
totals.lines.covered += lineCovered;
|
||||
totals.statements.total += statementTotal;
|
||||
totals.statements.covered += statementCovered;
|
||||
totals.branches.total += branchTotal;
|
||||
totals.branches.covered += branchCovered;
|
||||
totals.functions.total += functionTotal;
|
||||
totals.functions.covered += functionCovered;
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
|
||||
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
|
||||
lines: {
|
||||
covered: lineCovered || statementCovered,
|
||||
total: lineTotal || statementTotal,
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
|
||||
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('| Metric | Coverage | Covered / Total |');
|
||||
console.log('|--------|----------|-----------------|');
|
||||
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
|
||||
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
|
||||
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
|
||||
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>File coverage (lowest lines first)</summary>');
|
||||
console.log('');
|
||||
console.log('```');
|
||||
fileSummaries
|
||||
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
|
||||
.slice(0, 25)
|
||||
.forEach(({ file, pct, lines }) => {
|
||||
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
|
||||
});
|
||||
console.log('```');
|
||||
console.log('</details>');
|
||||
|
||||
if (coverage) {
|
||||
const pctValue = (covered, tot) => {
|
||||
if (tot === 0) {
|
||||
return '0';
|
||||
}
|
||||
return ((covered / tot) * 100)
|
||||
.toFixed(2)
|
||||
.replace(/\.?0+$/, '');
|
||||
};
|
||||
|
||||
const formatLineRanges = (lines) => {
|
||||
if (lines.length === 0) {
|
||||
return '';
|
||||
}
|
||||
const ranges = [];
|
||||
let start = lines[0];
|
||||
let end = lines[0];
|
||||
|
||||
for (let i = 1; i < lines.length; i += 1) {
|
||||
const current = lines[i];
|
||||
if (current === end + 1) {
|
||||
end = current;
|
||||
continue;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
start = current;
|
||||
end = current;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
return ranges.join(',');
|
||||
};
|
||||
|
||||
const tableTotals = {
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
lines: { covered: 0, total: 0 },
|
||||
};
|
||||
const tableRows = Object.entries(coverage)
|
||||
.map(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
tableTotals.lines.total += lineTotal;
|
||||
tableTotals.lines.covered += lineCovered;
|
||||
tableTotals.statements.total += statementTotal;
|
||||
tableTotals.statements.covered += statementCovered;
|
||||
tableTotals.branches.total += branchTotal;
|
||||
tableTotals.branches.covered += branchCovered;
|
||||
tableTotals.functions.total += functionTotal;
|
||||
tableTotals.functions.covered += functionCovered;
|
||||
|
||||
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
|
||||
|
||||
const filePath = entry.path ?? file;
|
||||
const relativePath = path.isAbsolute(filePath)
|
||||
? path.relative(process.cwd(), filePath)
|
||||
: filePath;
|
||||
|
||||
return {
|
||||
file: relativePath || file,
|
||||
statements: pctValue(statementCovered, statementTotal),
|
||||
branches: pctValue(branchCovered, branchTotal),
|
||||
functions: pctValue(functionCovered, functionTotal),
|
||||
lines: pctValue(lineCovered, lineTotal),
|
||||
uncovered: formatLineRanges(uncoveredLines),
|
||||
};
|
||||
})
|
||||
.sort((a, b) => a.file.localeCompare(b.file));
|
||||
|
||||
const columns = [
|
||||
{ key: 'file', header: 'File', align: 'left' },
|
||||
{ key: 'statements', header: '% Stmts', align: 'right' },
|
||||
{ key: 'branches', header: '% Branch', align: 'right' },
|
||||
{ key: 'functions', header: '% Funcs', align: 'right' },
|
||||
{ key: 'lines', header: '% Lines', align: 'right' },
|
||||
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
|
||||
];
|
||||
|
||||
const allFilesRow = {
|
||||
file: 'All files',
|
||||
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
|
||||
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
|
||||
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
|
||||
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
|
||||
uncovered: '',
|
||||
};
|
||||
|
||||
const rowsForOutput = [allFilesRow, ...tableRows];
|
||||
const formatRow = (row) => `| ${columns
|
||||
.map(({ key }) => String(row[key] ?? ''))
|
||||
.join(' | ')} |`;
|
||||
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
|
||||
const dividerRow = `| ${columns
|
||||
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
|
||||
.join(' | ')} |`;
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>Jest coverage table</summary>');
|
||||
console.log('');
|
||||
console.log(headerRow);
|
||||
console.log(dividerRow);
|
||||
rowsForOutput.forEach((row) => console.log(formatRow(row)));
|
||||
console.log('</details>');
|
||||
}
|
||||
NODE
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
retention-days: 30
|
||||
if-no-files-found: error
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@
|
|||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
|
|
|||
|
|
@ -690,3 +690,8 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
|
|||
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
|
||||
# Maximum number of concurrent annotation import tasks per tenant
|
||||
ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||
|
||||
# Sandbox expired records clean configuration
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@
|
|||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
|
|
|||
|
|
@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
|
|||
|
||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||
default=300.0,
|
||||
default=600.0,
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
|
@ -1270,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
|
|||
)
|
||||
|
||||
|
||||
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||
default=21,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
|
|
@ -1295,6 +1310,7 @@ class FeatureConfig(
|
|||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SandboxExpiredRecordsCleanConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
|
|||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
partial_member_list: list[str] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
external_knowledge_api_id: str | None = None
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ from .. import console_ns
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseModel):
|
||||
class CompletionMessageExplorePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
|
@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
|
|||
raise ValueError("must be a valid UUID") from exc
|
||||
|
||||
|
||||
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
|
||||
register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
|
|
@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
|
|||
endpoint="installed_app_completion",
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
|
||||
payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
|
|
@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
|
|
@ -24,7 +24,7 @@ from .. import console_ns
|
|||
|
||||
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
pinned: bool | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, inputs, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
|
|
@ -18,6 +19,15 @@ from services.account_service import TenantService
|
|||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class InstalledAppCreatePayload(BaseModel):
|
||||
app_id: str
|
||||
|
||||
|
||||
class InstalledAppUpdatePayload(BaseModel):
|
||||
is_pinned: bool | None = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
||||
args = parser.parse_args()
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
|
||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
|
||||
if recommended_app is None:
|
||||
raise NotFound("App not found")
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||
app = db.session.query(App).where(App.id == payload.app_id).first()
|
||||
|
||||
if app is None:
|
||||
raise NotFound("App not found")
|
||||
raise NotFound("App entity not found")
|
||||
|
||||
if not app.is_public:
|
||||
raise Forbidden("You can't install a non-public app")
|
||||
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
|
||||
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
|
|||
recommended_app.install_count += 1
|
||||
|
||||
new_installed_app = InstalledApp(
|
||||
app_id=args["app_id"],
|
||||
app_id=payload.app_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_owner_tenant_id=app.tenant_id,
|
||||
is_pinned=False,
|
||||
|
|
@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
|
|||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||
|
||||
def patch(self, installed_app):
|
||||
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
|
||||
args = parser.parse_args()
|
||||
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
commit_args = False
|
||||
if "is_pinned" in args:
|
||||
installed_app.is_pinned = args["is_pinned"]
|
||||
if payload.is_pinned is not None:
|
||||
installed_app.is_pinned = payload.is_pinned
|
||||
commit_args = True
|
||||
|
||||
if commit_args:
|
||||
|
|
|
|||
|
|
@ -1,31 +1,40 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 50:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
return name
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
parser_tags = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
class TagBindingPayload(BaseModel):
|
||||
tag_ids: list[str] = Field(description="Tag IDs to bind")
|
||||
target_id: str = Field(description="Target ID to bind tags to")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
class TagBindingRemovePayload(BaseModel):
|
||||
tag_id: str = Field(description="Tag ID to remove")
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -43,7 +52,7 @@ class TagListApi(Resource):
|
|||
|
||||
return tags, 200
|
||||
|
||||
@console_ns.expect(parser_tags)
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -53,22 +62,17 @@ class TagListApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_tags.parse_args()
|
||||
tag = TagService.save_tags(args)
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
parser_tag_id = reqparse.RequestParser().add_argument(
|
||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@console_ns.expect(parser_tag_id)
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_tag_id.parse_args()
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
|
|
@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource):
|
|||
return 204
|
||||
|
||||
|
||||
parser_create = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@console_ns.expect(parser_create)
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_create.parse_args()
|
||||
TagService.save_tag_binding(args)
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_remove = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@console_ns.expect(parser_remove)
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_remove.parse_args()
|
||||
TagService.delete_tag_binding(args)
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from controllers.console.wraps import (
|
|||
setup_required,
|
||||
)
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
|
|
@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource):
|
|||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider
|
||||
# Create provider in transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
|
|
@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource):
|
|||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@console_ns.expect(parser_mcp_put)
|
||||
@setup_required
|
||||
|
|
@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource):
|
|||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
||||
validation_result = None
|
||||
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
|
||||
validation_data = None
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
validation_result = service.validate_server_url_change(
|
||||
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
||||
validation_data = service.get_provider_for_url_validation(
|
||||
tenant_id=current_tenant_id, provider_id=args["provider_id"]
|
||||
)
|
||||
|
||||
# No need to check for errors here, exceptions will be raised directly
|
||||
# Step 2: Perform URL validation with network I/O OUTSIDE of any database session
|
||||
# This prevents holding database locks during potentially slow network operations
|
||||
validation_result = MCPToolManageService.validate_server_url_standalone(
|
||||
tenant_id=current_tenant_id,
|
||||
new_server_url=args["server_url"],
|
||||
validation_data=validation_data,
|
||||
)
|
||||
|
||||
# Step 2: Perform database update in a transaction
|
||||
# Step 3: Perform database update in a transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
|
|
@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource):
|
|||
authentication=authentication,
|
||||
validation_result=validation_result,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(parser_mcp_delete)
|
||||
@setup_required
|
||||
|
|
@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource):
|
|||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_auth = (
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel):
|
|||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
partial_member_list: list[str] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
external_knowledge_api_id: str | None = None
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, field_validator
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
|
|
@ -20,6 +21,7 @@ from controllers.web.error import (
|
|||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
|
|
@ -29,6 +31,25 @@ from services.errors.audio import (
|
|||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
from ..common.schema import register_schema_models
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
register_schema_models(web_ns, TextToAudioPayload)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -88,6 +109,7 @@ class AudioApi(WebApiResource):
|
|||
|
||||
@web_ns.route("/text-to-audio")
|
||||
class TextApi(WebApiResource):
|
||||
@web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
|
||||
@web_ns.doc("Text to Audio")
|
||||
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
|
||||
@web_ns.doc(
|
||||
|
|
@ -102,18 +124,11 @@ class TextApi(WebApiResource):
|
|||
def post(self, app_model: App, end_user):
|
||||
"""Convert text to audio"""
|
||||
try:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, required=False, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = payload.message_id
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask_restx import reqparse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
|
|
@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the completion")
|
||||
query: str = Field(default="", description="Query text for completion")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
|
||||
response_mode: Literal["blocking", "streaming"] | None = Field(
|
||||
default=None, description="Response mode: blocking or streaming"
|
||||
)
|
||||
retriever_from: str = Field(default="web_app", description="Source of retriever")
|
||||
|
||||
|
||||
class ChatMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the chat")
|
||||
query: str = Field(description="User query/message")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
|
||||
response_mode: Literal["blocking", "streaming"] | None = Field(
|
||||
default=None, description="Response mode: blocking or streaming"
|
||||
)
|
||||
conversation_id: str | None = Field(default=None, description="Conversation ID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message ID")
|
||||
retriever_from: str = Field(default="web_app", description="Source of retriever")
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
@web_ns.route("/completion-messages")
|
||||
class CompletionApi(WebApiResource):
|
||||
@web_ns.doc("Create Completion Message")
|
||||
@web_ns.doc(description="Create a completion message for text generation applications.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
|
||||
"query": {"description": "Query text for completion", "type": "string", "required": False},
|
||||
"files": {"description": "Files to be processed", "type": "array", "required": False},
|
||||
"response_mode": {
|
||||
"description": "Response mode: blocking or streaming",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": False,
|
||||
},
|
||||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
|
|
@ -67,18 +88,10 @@ class CompletionApi(WebApiResource):
|
|||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, location="json", default="")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
|
||||
)
|
||||
payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
|
|
@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource):
|
|||
class ChatApi(WebApiResource):
|
||||
@web_ns.doc("Create Chat Message")
|
||||
@web_ns.doc(description="Create a chat message for conversational applications.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
|
||||
"query": {"description": "User query/message", "type": "string", "required": True},
|
||||
"files": {"description": "Files to be processed", "type": "array", "required": False},
|
||||
"response_mode": {
|
||||
"description": "Response mode: blocking or streaming",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": False,
|
||||
},
|
||||
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
|
||||
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
|
||||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
|
|
@ -173,20 +171,10 @@ class ChatApi(WebApiResource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
|
||||
)
|
||||
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
|
@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
|
|||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict[str, Any] | None = Field(default=None)
|
||||
json_schema: str | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
|
|||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
def validate_json_schema(cls, schema: str | None) -> str | None:
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"invalid json_schema value {schema}")
|
||||
|
||||
try:
|
||||
Draft7Validator.check_schema(json_schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
|
|
@ -175,6 +176,13 @@ class BaseAppGenerator:
|
|||
value = True
|
||||
elif value == 0:
|
||||
value = False
|
||||
case VariableEntityType.JSON_OBJECT:
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a string")
|
||||
try:
|
||||
json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
|
|
|||
|
|
@ -342,9 +342,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
event_type=event_type,
|
||||
)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from threading import Thread
|
|||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -54,6 +54,20 @@ class MessageCycleManager:
|
|||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
self._message_has_file: set[str] = set()
|
||||
|
||||
def get_message_event_type(self, message_id: str) -> StreamEvent:
|
||||
if message_id in self._message_has_file:
|
||||
return StreamEvent.MESSAGE_FILE
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
|
||||
|
||||
if has_file:
|
||||
self._message_has_file.add(message_id)
|
||||
return StreamEvent.MESSAGE_FILE
|
||||
|
||||
return StreamEvent.MESSAGE
|
||||
|
||||
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
|
||||
"""
|
||||
|
|
@ -214,7 +228,11 @@ class MessageCycleManager:
|
|||
return None
|
||||
|
||||
def message_to_stream_response(
|
||||
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
|
||||
self,
|
||||
answer: str,
|
||||
message_id: str,
|
||||
from_variable_selector: list[str] | None = None,
|
||||
event_type: StreamEvent | None = None,
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
Message to stream response.
|
||||
|
|
@ -222,16 +240,12 @@ class MessageCycleManager:
|
|||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
|
||||
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
|
||||
|
||||
return MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer,
|
||||
from_variable_selector=from_variable_selector,
|
||||
event=event_type,
|
||||
event=event_type or StreamEvent.MESSAGE,
|
||||
)
|
||||
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
|
|
|
|||
|
|
@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
|
|||
"""
|
||||
Build a list of URLs to try for Protected Resource Metadata discovery.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
|
||||
Priority order:
|
||||
1. URL from WWW-Authenticate header (if provided)
|
||||
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
|
||||
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
|
||||
"""
|
||||
urls = []
|
||||
|
||||
|
|
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
|
|||
# Fallback: construct from server URL
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
if fallback_url not in urls:
|
||||
urls.append(fallback_url)
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
|
||||
if path:
|
||||
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
|
||||
if path_url not in urls:
|
||||
urls.append(path_url)
|
||||
|
||||
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
|
||||
root_url = f"{base_url}/.well-known/oauth-protected-resource"
|
||||
if root_url not in urls:
|
||||
urls.append(root_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
|
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
|
|||
|
||||
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
||||
|
||||
Per RFC 8414 section 3:
|
||||
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
|
||||
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
|
||||
|
||||
Example:
|
||||
- issuer: https://example.com/oauth
|
||||
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
|
||||
Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
|
||||
- OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
|
||||
- OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
|
||||
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
|
||||
- OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
|
||||
- OpenID Connect at root: https://example.com/.well-known/openid-configuration
|
||||
"""
|
||||
urls = []
|
||||
base_url = auth_server_url or server_url
|
||||
|
||||
parsed = urlparse(base_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/") # Remove trailing slash
|
||||
path = parsed.path.rstrip("/")
|
||||
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
|
||||
urls.append(f"{base}/.well-known/oauth-authorization-server")
|
||||
|
||||
# Try OpenID Connect discovery first (more common)
|
||||
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
|
||||
# OpenID Connect Discovery at root
|
||||
urls.append(f"{base}/.well-known/openid-configuration")
|
||||
|
||||
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
|
||||
# Include the path component if present in the issuer URL
|
||||
if path:
|
||||
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
|
||||
else:
|
||||
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
|
||||
# OpenID Connect Discovery with path insertion
|
||||
urls.append(f"{base}/.well-known/openid-configuration{path}")
|
||||
|
||||
# OpenID Connect Discovery path appending
|
||||
urls.append(f"{base}{path}/.well-known/openid-configuration")
|
||||
|
||||
# OAuth 2.0 Authorization Server Metadata with path insertion
|
||||
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
||||
|
||||
return urls
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class MCPClient:
|
|||
try:
|
||||
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
|
||||
self.connect_server(sse_client, "sse")
|
||||
except MCPConnectionError:
|
||||
except (MCPConnectionError, ValueError):
|
||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
|
|
|
|||
|
|
@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model
|
|||
|
||||
- Model provider display
|
||||
|
||||

|
||||
|
||||
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
|
||||
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.
|
||||
|
||||
- Selectable model list display
|
||||
|
||||

|
||||
|
||||
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
|
||||
|
||||
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
|
||||
|
||||

|
||||
|
||||
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
|
||||
In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
|
||||
|
||||
- Provider/model credential authentication
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
|
||||
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
|
||||
|
||||
## Structure
|
||||
|
||||

|
||||
|
||||
Model Runtime is divided into three layers:
|
||||
|
||||
- The outermost layer is the factory method
|
||||
|
|
@ -60,9 +46,6 @@ Model Runtime is divided into three layers:
|
|||
|
||||
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
|
||||
|
||||
## Next Steps
|
||||
## Documentation
|
||||
|
||||
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
|
||||
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
|
||||
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
|
||||
- Implement interface methods: [Link](./docs/en_US/interfaces.md)
|
||||
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
|
||||
|
|
|
|||
|
|
@ -18,34 +18,20 @@
|
|||
|
||||
- 模型供应商展示
|
||||
|
||||

|
||||
|
||||
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
|
||||
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。
|
||||
|
||||
- 可选择的模型列表展示
|
||||
|
||||

|
||||
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
||||
|
||||
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
||||
|
||||
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
|
||||
|
||||

|
||||
|
||||
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
|
||||
除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。
|
||||
|
||||
- 供应商/模型凭据鉴权
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
|
||||
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。
|
||||
|
||||
## 结构
|
||||
|
||||

|
||||
|
||||
Model Runtime 分三层:
|
||||
|
||||
- 最外层为工厂方法
|
||||
|
|
@ -59,8 +45,7 @@ Model Runtime 分三层:
|
|||
对于供应商/模型凭据,有两种情况
|
||||
|
||||
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
|
||||
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
||||

|
||||
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
||||
|
||||
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
|
||||
|
||||
|
|
@ -74,20 +59,6 @@ Model Runtime 分三层:
|
|||
|
||||
- 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
|
||||
|
||||
## 下一步
|
||||
## 文档
|
||||
|
||||
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
|
||||
|
||||
当添加后,这里将会出现一个新的供应商
|
||||
|
||||

|
||||
|
||||
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
|
||||
|
||||
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
|
||||
|
||||

|
||||
|
||||
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
|
||||
|
||||
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
|
||||
有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ from core.trigger.errors import (
|
|||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
_plugin_daemon_timeout_config = cast(
|
||||
float | httpx.Timeout | None,
|
||||
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
|
||||
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
|
||||
)
|
||||
plugin_daemon_request_timeout: httpx.Timeout | None
|
||||
if _plugin_daemon_timeout_config is None:
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor):
|
|||
def _extract_images_from_docx(self, doc):
|
||||
image_count = 0
|
||||
image_map = {}
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
for r_id, rel in doc.part.rels.items():
|
||||
if "image" in rel.target_ref:
|
||||
|
|
@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor):
|
|||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
# Use r_id as key for external images since target_part is undefined
|
||||
image_map[r_id] = f""
|
||||
image_map[r_id] = f""
|
||||
else:
|
||||
image_ext = rel.target_ref.split(".")[-1]
|
||||
if image_ext is None:
|
||||
|
|
@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor):
|
|||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
# Use target_part as key for internal images
|
||||
image_map[rel.target_part] = (
|
||||
f""
|
||||
)
|
||||
image_map[rel.target_part] = f""
|
||||
db.session.commit()
|
||||
return image_map
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|||
def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._fixed_separator = fixed_separator
|
||||
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
|
||||
self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
|
|
@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|||
splits = re.split(r" +", text)
|
||||
else:
|
||||
splits = text.split(separator)
|
||||
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
|
||||
if self._keep_separator:
|
||||
splits = [s + separator for s in splits[:-1]] + splits[-1:]
|
||||
else:
|
||||
splits = list(text)
|
||||
if separator == "\n":
|
||||
|
|
@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|||
splits = [s for s in splits if (s not in {"", "\n"})]
|
||||
_good_splits = []
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
_separator = separator if self._keep_separator else ""
|
||||
_separator = "" if self._keep_separator else separator
|
||||
s_lens = self._length_function(splits)
|
||||
if separator != "":
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
|
|
|
|||
|
|
@ -86,6 +86,11 @@ class Executor:
|
|||
node_data.authorization.config.api_key = variable_pool.convert_template(
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
# Validate that API key is not empty after template conversion
|
||||
if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip():
|
||||
raise AuthorizationConfigError(
|
||||
"API key is required for authorization but was empty. Please provide a valid API key."
|
||||
)
|
||||
|
||||
self.url = node_data.url
|
||||
self.method = node_data.method
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
|
|
@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]):
|
|||
if value is None and variable.required:
|
||||
raise ValueError(f"{key} is required in input form")
|
||||
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"{key} must be a JSON object")
|
||||
|
||||
schema = variable.json_schema
|
||||
if not schema:
|
||||
continue
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
||||
try:
|
||||
Draft7Validator(schema).validate(value)
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{schema} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
json_value = json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{value} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
Draft7Validator(json_schema).validate(json_value)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
|
||||
node_inputs[key] = value
|
||||
node_inputs[key] = json_value
|
||||
|
|
|
|||
|
|
@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
fi
|
||||
else
|
||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||
|
|
@ -69,6 +69,53 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||
|
||||
elif [[ "${MODE}" == "beat" ]]; then
|
||||
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
|
||||
|
||||
elif [[ "${MODE}" == "job" ]]; then
|
||||
# Job mode: Run a one-time Flask command and exit
|
||||
# Pass Flask command and arguments via container args
|
||||
# Example K8s usage:
|
||||
# args:
|
||||
# - create-tenant
|
||||
# - --email
|
||||
# - admin@example.com
|
||||
#
|
||||
# Example Docker usage:
|
||||
# docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com
|
||||
|
||||
if [[ $# -eq 0 ]]; then
|
||||
echo "Error: No command specified for job mode."
|
||||
echo ""
|
||||
echo "Usage examples:"
|
||||
echo " Kubernetes:"
|
||||
echo " args: [create-tenant, --email, admin@example.com]"
|
||||
echo ""
|
||||
echo " Docker:"
|
||||
echo " docker run -e MODE=job dify-api create-tenant --email admin@example.com"
|
||||
echo ""
|
||||
echo "Available commands:"
|
||||
echo " create-tenant, reset-password, reset-email, upgrade-db,"
|
||||
echo " vdb-migrate, install-plugins, and more..."
|
||||
echo ""
|
||||
echo "Run 'flask --help' to see all available commands."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Running Flask job command: flask $*"
|
||||
|
||||
# Temporarily disable exit on error to capture exit code
|
||||
set +e
|
||||
flask "$@"
|
||||
JOB_EXIT_CODE=$?
|
||||
set -e
|
||||
|
||||
if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then
|
||||
echo "Job completed successfully."
|
||||
else
|
||||
echo "Job failed with exit code ${JOB_EXIT_CODE}."
|
||||
fi
|
||||
|
||||
exit ${JOB_EXIT_CODE}
|
||||
|
||||
else
|
||||
if [[ "${DEBUG}" == "true" ]]; then
|
||||
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
|
||||
|
|
|
|||
|
|
@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage):
|
|||
if not self.exists(path):
|
||||
raise FileNotFoundError("Path not found")
|
||||
|
||||
all_files = self.op.scan(path=path)
|
||||
# Use the new OpenDAL 0.46.0+ API with recursive listing
|
||||
lister = self.op.list(path, recursive=True)
|
||||
if files and directories:
|
||||
logger.debug("files and directories on %s scanned", path)
|
||||
return [f.path for f in all_files]
|
||||
return [entry.path for entry in lister]
|
||||
if files:
|
||||
logger.debug("files on %s scanned", path)
|
||||
return [f.path for f in all_files if not f.path.endswith("/")]
|
||||
return [entry.path for entry in lister if not entry.metadata.is_dir]
|
||||
elif directories:
|
||||
logger.debug("directories on %s scanned", path)
|
||||
return [f.path for f in all_files if f.path.endswith("/")]
|
||||
return [entry.path for entry in lister if entry.metadata.is_dir]
|
||||
else:
|
||||
raise ValueError("At least one of files or directories must be True")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping
|
|||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
|
||||
from uuid import UUID
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
|
|
@ -119,6 +120,19 @@ def uuid_value(value: Any) -> str:
|
|||
raise ValueError(error)
|
||||
|
||||
|
||||
def normalize_uuid(value: str | UUID) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
|
||||
try:
|
||||
return uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("must be a valid UUID") from exc
|
||||
|
||||
|
||||
UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)]
|
||||
|
||||
|
||||
def alphanumeric(value: str):
|
||||
# check if the value is alphanumeric and underlined
|
||||
if re.match(r"^[a-zA-Z0-9_]+$", value):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||
from typing_extensions import TypedDict
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
|
|
@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client
|
|||
from libs.helper import RateLimiter
|
||||
from models import Account, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubscriptionPlan(TypedDict):
|
||||
"""Tenant subscriptionplan information."""
|
||||
|
||||
plan: str
|
||||
expiration_date: int
|
||||
|
||||
|
||||
class BillingService:
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
|
|
@ -239,3 +252,39 @@ class BillingService:
|
|||
def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
|
||||
payload = {"account_id": account_id, "click_id": click_id}
|
||||
return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)
|
||||
|
||||
@classmethod
|
||||
def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
|
||||
"""
|
||||
Bulk fetch billing subscription plan via billing API.
|
||||
Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
|
||||
Returns:
|
||||
Mapping of tenant_id -> {plan: str, expiration_date: int}
|
||||
"""
|
||||
results: dict[str, SubscriptionPlan] = {}
|
||||
subscription_adapter = TypeAdapter(SubscriptionPlan)
|
||||
|
||||
chunk_size = 200
|
||||
for i in range(0, len(tenant_ids), chunk_size):
|
||||
chunk = tenant_ids[i : i + chunk_size]
|
||||
try:
|
||||
resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
|
||||
data = resp.get("data", {})
|
||||
|
||||
for tenant_id, plan in data.items():
|
||||
subscription_plan = subscription_adapter.validate_python(plan)
|
||||
results[tenant_id] = subscription_plan
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
|
||||
resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
|
||||
data = resp.get("data", [])
|
||||
tenant_whitelist = []
|
||||
for item in data:
|
||||
tenant_whitelist.append(item["tenant_id"])
|
||||
return tenant_whitelist
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel):
|
|||
description: str
|
||||
icon_info: IconInfo
|
||||
permission: str
|
||||
partial_member_list: list[str] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
yaml_content: str | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
|
|||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
|
|
@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
|
|||
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||
|
||||
|
||||
class ProviderUrlValidationData(BaseModel):
|
||||
"""Data required for URL validation, extracted from database to perform network operations outside of session"""
|
||||
|
||||
current_server_url_hash: str
|
||||
headers: dict[str, str]
|
||||
timeout: float | None
|
||||
sse_read_timeout: float | None
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""Service class for managing MCP tools and providers."""
|
||||
|
||||
|
|
@ -166,9 +174,6 @@ class MCPToolManageService:
|
|||
self._session.add(mcp_tool)
|
||||
self._session.flush()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
return mcp_providers
|
||||
|
||||
|
|
@ -192,7 +197,7 @@ class MCPToolManageService:
|
|||
Update an MCP provider.
|
||||
|
||||
Args:
|
||||
validation_result: Pre-validation result from validate_server_url_change.
|
||||
validation_result: Pre-validation result from validate_server_url_standalone.
|
||||
If provided and contains reconnect_result, it will be used
|
||||
instead of performing network operations.
|
||||
"""
|
||||
|
|
@ -251,8 +256,6 @@ class MCPToolManageService:
|
|||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
except IntegrityError as e:
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
|
||||
|
|
@ -261,9 +264,6 @@ class MCPToolManageService:
|
|||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
def list_providers(
|
||||
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||
) -> list[ToolProviderApiEntity]:
|
||||
|
|
@ -546,30 +546,39 @@ class MCPToolManageService:
|
|||
)
|
||||
return self.execute_auth_actions(auth_result)
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||
def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
|
||||
"""
|
||||
Get provider data required for URL validation.
|
||||
This method performs database read and should be called within a session.
|
||||
|
||||
Returns:
|
||||
ProviderUrlValidationData: Data needed for standalone URL validation
|
||||
"""
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
provider_entity = provider.to_entity()
|
||||
headers = provider_entity.headers
|
||||
return ProviderUrlValidationData(
|
||||
current_server_url_hash=provider.server_url_hash,
|
||||
headers=provider_entity.headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def validate_server_url_change(
|
||||
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
||||
@staticmethod
|
||||
def validate_server_url_standalone(
|
||||
*,
|
||||
tenant_id: str,
|
||||
new_server_url: str,
|
||||
validation_data: ProviderUrlValidationData,
|
||||
) -> ServerUrlValidationResult:
|
||||
"""
|
||||
Validate server URL change by attempting to connect to the new server.
|
||||
This method should be called BEFORE update_provider to perform network operations
|
||||
outside of the database transaction.
|
||||
This method performs network operations and MUST be called OUTSIDE of any database session
|
||||
to avoid holding locks during network I/O.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID for encryption
|
||||
new_server_url: The new server URL to validate
|
||||
validation_data: Provider data obtained from get_provider_for_url_validation
|
||||
|
||||
Returns:
|
||||
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||
|
|
@ -579,25 +588,30 @@ class MCPToolManageService:
|
|||
return ServerUrlValidationResult(needs_validation=False)
|
||||
|
||||
# Validate URL format
|
||||
if not self._is_valid_url(new_server_url):
|
||||
parsed = urlparse(new_server_url)
|
||||
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
# Always encrypt and hash the URL
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||
|
||||
# Get current provider
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Check if URL is actually different
|
||||
if new_server_url_hash == provider.server_url_hash:
|
||||
if new_server_url_hash == validation_data.current_server_url_hash:
|
||||
# URL hasn't changed, but still return the encrypted data
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
||||
needs_validation=False,
|
||||
encrypted_server_url=encrypted_server_url,
|
||||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
# Perform validation by attempting to connect
|
||||
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
||||
# Perform network validation - this is the expensive operation that should be outside session
|
||||
reconnect_result = MCPToolManageService._reconnect_with_url(
|
||||
server_url=new_server_url,
|
||||
headers=validation_data.headers,
|
||||
timeout=validation_data.timeout,
|
||||
sse_read_timeout=validation_data.sse_read_timeout,
|
||||
)
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=True,
|
||||
validation_passed=True,
|
||||
|
|
@ -606,6 +620,38 @@ class MCPToolManageService:
|
|||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reconnect_with_url(
|
||||
*,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
timeout: float | None,
|
||||
sse_read_timeout: float | None,
|
||||
) -> ReconnectResult:
|
||||
"""
|
||||
Attempt to connect to MCP server with given URL.
|
||||
This is a static method that performs network I/O without database access.
|
||||
"""
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def _build_tool_provider_response(
|
||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||
) -> ToolProviderApiEntity:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import logging
|
|||
import time
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
|
|
@ -12,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
|
|||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.source import DataSourceOauthBinding
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -48,27 +47,36 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
|||
page_id = data_source_info["notion_page_id"]
|
||||
page_type = data_source_info["type"]
|
||||
page_edited_time = data_source_info["last_edited_time"]
|
||||
credential_id = data_source_info.get("credential_id")
|
||||
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
sa.and_(
|
||||
DataSourceOauthBinding.tenant_id == document.tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
.first()
|
||||
# Get credentials from datasource provider
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=document.tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if not data_source_binding:
|
||||
raise ValueError("Data source binding not found.")
|
||||
|
||||
if not credential:
|
||||
logger.error(
|
||||
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
|
||||
document_id,
|
||||
document.tenant_id,
|
||||
credential_id,
|
||||
)
|
||||
document.indexing_status = "error"
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
loader = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=document.tenant_id,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
|
|
@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
|
||||
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
|
||||
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
)
|
||||
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
|
|||
ssl_verify=True,
|
||||
)
|
||||
|
||||
# Create executor
|
||||
executor = Executor(
|
||||
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
|
||||
)
|
||||
|
||||
# Get assembled headers
|
||||
headers = executor._assembling_headers()
|
||||
|
||||
# When api_key is empty, the custom header should NOT be set
|
||||
assert "X-Custom-Auth" not in headers
|
||||
# Create executor should raise AuthorizationConfigError
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
|
|
@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
|
|||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
||||
"""
|
||||
Test that custom authorization doesn't set header when api_key is empty.
|
||||
This test verifies the fix for issue #23554.
|
||||
Test that custom authorization raises error when api_key is empty.
|
||||
This test verifies the fix for issue #21830.
|
||||
"""
|
||||
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
|
|
@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
|||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
# Custom header should NOT be set when api_key is empty
|
||||
assert "X-Custom-Auth:" not in data
|
||||
# Should fail with AuthorizationConfigError
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "API key is required" in result.error
|
||||
assert result.error_type == "AuthorizationConfigError"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from models import Account, Tenant
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
|
@ -298,7 +300,7 @@ class TestApiToolManageService:
|
|||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
|
||||
schema_type = "openapi"
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
|
|
@ -364,7 +366,7 @@ class TestApiToolManageService:
|
|||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none"}
|
||||
schema_type = "openapi"
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
|
|
@ -428,21 +430,10 @@ class TestApiToolManageService:
|
|||
labels = ["test"]
|
||||
|
||||
# Act & Assert: Try to create provider with invalid schema type
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
TypeAdapter(ApiProviderSchemaType).validate_python(schema_type)
|
||||
|
||||
assert "invalid schema type" in str(exc_info.value)
|
||||
assert "validation error" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
|
|
@ -464,7 +455,7 @@ class TestApiToolManageService:
|
|||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {} # Missing auth_type
|
||||
schema_type = "openapi"
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
|
|
@ -507,7 +498,7 @@ class TestApiToolManageService:
|
|||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔑"}
|
||||
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
|
||||
schema_type = "openapi"
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
|
|
|
|||
|
|
@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
|
|||
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
|
||||
]
|
||||
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||
# Setup mock client
|
||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||
mock_client_instance.list_tools.return_value = mock_tools
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
result = service._reconnect_provider(
|
||||
result = MCPToolManageService._reconnect_with_url(
|
||||
server_url="https://example.com/mcp",
|
||||
provider=mcp_provider,
|
||||
headers={"X-Test": "1"},
|
||||
timeout=mcp_provider.timeout,
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
|
@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
|
|||
assert tools_data[1]["name"] == "test_tool_2"
|
||||
|
||||
# Verify mock interactions
|
||||
provider_entity = mcp_provider.to_entity()
|
||||
mock_mcp_client.assert_called_once()
|
||||
mock_mcp_client.assert_called_once_with(
|
||||
server_url="https://example.com/mcp",
|
||||
headers={"X-Test": "1"},
|
||||
timeout=mcp_provider.timeout,
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
|
|
@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
|
|||
)
|
||||
|
||||
# Mock MCPClient to raise authentication error
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||
from core.mcp.error import MCPAuthError
|
||||
|
||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
result = service._reconnect_provider(
|
||||
result = MCPToolManageService._reconnect_with_url(
|
||||
server_url="https://example.com/mcp",
|
||||
provider=mcp_provider,
|
||||
headers={},
|
||||
timeout=mcp_provider.timeout,
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
|
@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
|
|||
)
|
||||
|
||||
# Mock MCPClient to raise connection error
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
||||
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||
from core.mcp.error import MCPError
|
||||
|
||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
||||
service._reconnect_provider(
|
||||
MCPToolManageService._reconnect_with_url(
|
||||
server_url="https://example.com/mcp",
|
||||
provider=mcp_provider,
|
||||
headers={"X-Test": "1"},
|
||||
timeout=mcp_provider.timeout,
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,420 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentMessageEvent,
|
||||
QueueErrorEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueuePingEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
EasyUITaskState,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamEvent,
|
||||
)
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher
|
||||
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
|
||||
"""Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_application_generate_entity(self):
|
||||
"""Create a mock application generate entity."""
|
||||
entity = Mock(spec=ChatAppGenerateEntity)
|
||||
entity.task_id = "test-task-id"
|
||||
entity.app_id = "test-app-id"
|
||||
# minimal app_config used by pipeline internals
|
||||
entity.app_config = SimpleNamespace(
|
||||
tenant_id="test-tenant-id",
|
||||
app_id="test-app-id",
|
||||
app_mode=AppMode.CHAT,
|
||||
app_model_config_dict={},
|
||||
additional_features=None,
|
||||
sensitive_word_avoidance=None,
|
||||
)
|
||||
# minimal model_conf for LLMResult init
|
||||
entity.model_conf = SimpleNamespace(
|
||||
model="test-model",
|
||||
provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
|
||||
credentials={},
|
||||
)
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
"""Create a mock queue manager."""
|
||||
manager = Mock(spec=AppQueueManager)
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_cycle_manager(self):
|
||||
"""Create a mock message cycle manager."""
|
||||
manager = Mock()
|
||||
manager.get_message_event_type.return_value = StreamEvent.MESSAGE
|
||||
manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
|
||||
manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
|
||||
manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
|
||||
manager.handle_retriever_resources = Mock()
|
||||
manager.handle_annotation_reply.return_value = None
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation(self):
|
||||
"""Create a mock conversation."""
|
||||
conversation = Mock()
|
||||
conversation.id = "test-conversation-id"
|
||||
conversation.mode = "chat"
|
||||
return conversation
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message."""
|
||||
message = Mock()
|
||||
message.id = "test-message-id"
|
||||
message.created_at = Mock()
|
||||
message.created_at.timestamp.return_value = 1234567890
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task_state(self):
|
||||
"""Create a mock task state."""
|
||||
task_state = Mock(spec=EasyUITaskState)
|
||||
|
||||
# Create LLM result mock
|
||||
llm_result = Mock(spec=RuntimeLLMResult)
|
||||
llm_result.prompt_messages = []
|
||||
llm_result.message = Mock()
|
||||
llm_result.message.content = ""
|
||||
|
||||
task_state.llm_result = llm_result
|
||||
task_state.answer = ""
|
||||
|
||||
return task_state
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(
|
||||
self,
|
||||
mock_application_generate_entity,
|
||||
mock_queue_manager,
|
||||
mock_conversation,
|
||||
mock_message,
|
||||
mock_message_cycle_manager,
|
||||
mock_task_state,
|
||||
):
|
||||
"""Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies."""
|
||||
with patch(
|
||||
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
|
||||
):
|
||||
pipeline = EasyUIBasedGenerateTaskPipeline(
|
||||
application_generate_entity=mock_application_generate_entity,
|
||||
queue_manager=mock_queue_manager,
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
stream=True,
|
||||
)
|
||||
pipeline._message_cycle_manager = mock_message_cycle_manager
|
||||
pipeline._task_state = mock_task_state
|
||||
return pipeline
|
||||
|
||||
def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
|
||||
self, pipeline, mock_message_cycle_manager
|
||||
):
|
||||
"""Expect get_message_event_type to be called when processing the first LLM chunk event."""
|
||||
# Setup a minimal LLM chunk event
|
||||
chunk = Mock()
|
||||
chunk.delta.message.content = "hi"
|
||||
chunk.prompt_messages = []
|
||||
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event.chunk = chunk
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = llm_chunk_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
# Execute
|
||||
list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")
|
||||
|
||||
def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling of LLM chunk events with text content."""
|
||||
# Setup
|
||||
chunk = Mock()
|
||||
chunk.delta.message.content = "Hello, world!"
|
||||
chunk.prompt_messages = []
|
||||
|
||||
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event.chunk = chunk
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = llm_chunk_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
|
||||
answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
assert mock_task_state.llm_result.message.content == "Hello, world!"
|
||||
|
||||
def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling of LLM chunk events with list content."""
|
||||
# Setup
|
||||
text_content = Mock(spec=TextPromptMessageContent)
|
||||
text_content.data = "Hello"
|
||||
|
||||
chunk = Mock()
|
||||
chunk.delta.message.content = [text_content, " world!"]
|
||||
chunk.prompt_messages = []
|
||||
|
||||
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event.chunk = chunk
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = llm_chunk_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
|
||||
answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
assert mock_task_state.llm_result.message.content == "Hello world!"
|
||||
|
||||
def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling of agent message events."""
|
||||
# Setup
|
||||
chunk = Mock()
|
||||
chunk.delta.message.content = "Agent response"
|
||||
|
||||
agent_message_event = Mock(spec=QueueAgentMessageEvent)
|
||||
agent_message_event.chunk = chunk
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = agent_message_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
# Ensure method under assertion is a mock to track calls
|
||||
pipeline._agent_message_to_stream_response = Mock(return_value=Mock())
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
# Agent messages should use _agent_message_to_stream_response
|
||||
pipeline._agent_message_to_stream_response.assert_called_once_with(
|
||||
answer="Agent response", message_id="test-message-id"
|
||||
)
|
||||
|
||||
def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling of message end events."""
|
||||
# Setup
|
||||
llm_result = Mock(spec=RuntimeLLMResult)
|
||||
llm_result.message = Mock()
|
||||
llm_result.message.content = "Final response"
|
||||
|
||||
message_end_event = Mock(spec=QueueMessageEndEvent)
|
||||
message_end_event.llm_result = llm_result
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = message_end_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline._save_message = Mock()
|
||||
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
|
||||
|
||||
# Patch db.engine used inside pipeline for session creation
|
||||
with patch(
|
||||
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
|
||||
):
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
assert mock_task_state.llm_result == llm_result
|
||||
pipeline._save_message.assert_called_once()
|
||||
pipeline._message_end_to_stream_response.assert_called_once()
|
||||
|
||||
def test_error_event(self, pipeline):
|
||||
"""Test handling of error events."""
|
||||
# Setup
|
||||
error_event = Mock(spec=QueueErrorEvent)
|
||||
error_event.error = Exception("Test error")
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = error_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline.handle_error = Mock(return_value=Exception("Test error"))
|
||||
pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))
|
||||
|
||||
# Patch db.engine used inside pipeline for session creation
|
||||
with patch(
|
||||
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
|
||||
):
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
pipeline.handle_error.assert_called_once()
|
||||
pipeline.error_to_stream_response.assert_called_once()
|
||||
|
||||
def test_ping_event(self, pipeline):
|
||||
"""Test handling of ping events."""
|
||||
# Setup
|
||||
ping_event = Mock(spec=QueuePingEvent)
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = ping_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
pipeline.ping_stream_response.assert_called_once()
|
||||
|
||||
def test_file_event(self, pipeline, mock_message_cycle_manager):
|
||||
"""Test handling of file events."""
|
||||
# Setup
|
||||
file_event = Mock(spec=QueueMessageFileEvent)
|
||||
file_event.message_file_id = "file-id"
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = file_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
file_response = Mock(spec=MessageFileStreamResponse)
|
||||
mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 1
|
||||
assert responses[0] == file_response
|
||||
mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event)
|
||||
|
||||
def test_publisher_is_called_with_messages(self, pipeline):
|
||||
"""Test that publisher publishes messages when provided."""
|
||||
# Setup
|
||||
publisher = Mock(spec=AppGeneratorTTSPublisher)
|
||||
|
||||
ping_event = Mock(spec=QueuePingEvent)
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = ping_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
|
||||
|
||||
# Execute
|
||||
list(pipeline._process_stream_response(publisher=publisher, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
# Called once with message and once with None at the end
|
||||
assert publisher.publish.call_count == 2
|
||||
publisher.publish.assert_any_call(mock_queue_message)
|
||||
publisher.publish.assert_any_call(None)
|
||||
|
||||
def test_trace_manager_passed_to_save_message(self, pipeline):
|
||||
"""Test that trace manager is passed to _save_message."""
|
||||
# Setup
|
||||
trace_manager = Mock(spec=TraceQueueManager)
|
||||
|
||||
message_end_event = Mock(spec=QueueMessageEndEvent)
|
||||
message_end_event.llm_result = None
|
||||
|
||||
mock_queue_message = Mock()
|
||||
mock_queue_message.event = message_end_event
|
||||
pipeline.queue_manager.listen.return_value = [mock_queue_message]
|
||||
|
||||
pipeline._save_message = Mock()
|
||||
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
|
||||
|
||||
# Patch db.engine used inside pipeline for session creation
|
||||
with patch(
|
||||
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
|
||||
):
|
||||
# Execute
|
||||
list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager))
|
||||
|
||||
# Assert
|
||||
pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager)
|
||||
|
||||
def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
|
||||
"""Test handling multiple events in sequence."""
|
||||
# Setup
|
||||
chunk1 = Mock()
|
||||
chunk1.delta.message.content = "Hello"
|
||||
chunk1.prompt_messages = []
|
||||
|
||||
chunk2 = Mock()
|
||||
chunk2.delta.message.content = " world!"
|
||||
chunk2.prompt_messages = []
|
||||
|
||||
llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event1.chunk = chunk1
|
||||
|
||||
ping_event = Mock(spec=QueuePingEvent)
|
||||
|
||||
llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent)
|
||||
llm_chunk_event2.chunk = chunk2
|
||||
|
||||
mock_queue_messages = [
|
||||
Mock(event=llm_chunk_event1),
|
||||
Mock(event=ping_event),
|
||||
Mock(event=llm_chunk_event2),
|
||||
]
|
||||
pipeline.queue_manager.listen.return_value = mock_queue_messages
|
||||
|
||||
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
|
||||
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
|
||||
|
||||
# Execute
|
||||
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 3
|
||||
assert mock_task_state.llm_result.message.content == "Hello world!"
|
||||
|
||||
# Verify calls to message_to_stream_response
|
||||
assert mock_message_cycle_manager.message_to_stream_response.call_count == 2
|
||||
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
|
||||
answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
|
||||
answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
"""Unit tests for the message cycle manager optimization."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
|
||||
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
|
||||
|
||||
class TestMessageCycleManagerOptimization:
|
||||
"""Test cases for the message cycle manager optimization that prevents N+1 queries."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_application_generate_entity(self):
|
||||
"""Create a mock application generate entity."""
|
||||
entity = Mock()
|
||||
entity.task_id = "test-task-id"
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def message_cycle_manager(self, mock_application_generate_entity):
|
||||
"""Create a message cycle manager instance."""
|
||||
task_state = Mock()
|
||||
return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)
|
||||
|
||||
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
result = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
|
||||
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE when message has no files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
# Setup mock session and no message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
result = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
|
||||
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute: compute event type once, then pass to message_to_stream_response
|
||||
with current_app.app_context():
|
||||
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageStreamResponse)
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
|
||||
"""Test that message_to_stream_response skips database query when event_type is provided."""
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
# Execute with event_type provided
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageStreamResponse)
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE
|
||||
# Should not query database when event_type is provided
|
||||
mock_session_class.assert_not_called()
|
||||
|
||||
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
|
||||
"""Test message_to_stream_response with from_variable_selector parameter."""
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world",
|
||||
message_id="test-message-id",
|
||||
from_variable_selector=["var1", "var2"],
|
||||
event_type=StreamEvent.MESSAGE,
|
||||
)
|
||||
|
||||
assert isinstance(result, MessageStreamResponse)
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.from_variable_selector == ["var1", "var2"]
|
||||
assert result.event == StreamEvent.MESSAGE
|
||||
|
||||
def test_optimization_usage_example(self, message_cycle_manager):
|
||||
"""Test the optimization pattern that should be used by callers."""
|
||||
# Step 1: Get event type once (this queries database)
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None # No files
|
||||
with current_app.app_context():
|
||||
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Should query database once
|
||||
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
|
||||
assert event_type == StreamEvent.MESSAGE
|
||||
|
||||
# Step 2: Use event_type for multiple calls (no additional queries)
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
mock_session_class.return_value.__enter__.return_value = Mock()
|
||||
|
||||
chunk1_response = message_cycle_manager.message_to_stream_response(
|
||||
answer="Chunk 1", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
chunk2_response = message_cycle_manager.message_to_stream_response(
|
||||
answer="Chunk 2", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
# Should not query database again
|
||||
mock_session_class.assert_not_called()
|
||||
|
||||
assert chunk1_response.event == StreamEvent.MESSAGE
|
||||
assert chunk2_response.event == StreamEvent.MESSAGE
|
||||
assert chunk1_response.answer == "Chunk 1"
|
||||
assert chunk2_response.answer == "Chunk 2"
|
||||
|
|
@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch):
|
|||
# DB interactions should be recorded
|
||||
assert len(db_stub.session.added) == 2
|
||||
assert db_stub.session.committed is True
|
||||
|
||||
|
||||
def test_extract_images_from_docx_uses_internal_files_url():
|
||||
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
|
||||
# Test the URL generation logic directly
|
||||
from configs import dify_config
|
||||
|
||||
# Mock the configuration values
|
||||
original_files_url = getattr(dify_config, "FILES_URL", None)
|
||||
original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
|
||||
|
||||
try:
|
||||
# Set both URLs - INTERNAL should take precedence
|
||||
dify_config.FILES_URL = "http://external.example.com"
|
||||
dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"
|
||||
|
||||
# Test the URL generation logic (same as in word_extractor.py)
|
||||
upload_file_id = "test_file_id"
|
||||
|
||||
# This is the pattern we fixed in the word extractor
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
generated_url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||
|
||||
# Verify that INTERNAL_FILES_URL is used instead of FILES_URL
|
||||
assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
|
||||
assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
if original_files_url is not None:
|
||||
dify_config.FILES_URL = original_files_url
|
||||
if original_internal_files_url is not None:
|
||||
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||
|
|
|
|||
|
|
@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter:
|
|||
# Verify no empty chunks
|
||||
assert all(len(chunk) > 0 for chunk in result)
|
||||
|
||||
def test_double_slash_n(self):
|
||||
data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
|
||||
separator = "\\n\\n---\\n\\n"
|
||||
splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator)
|
||||
chunks = splitter.split_text(data)
|
||||
assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Metadata Preservation
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import pytest
|
||||
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNodeAuthorization,
|
||||
|
|
@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import (
|
|||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
@ -348,3 +351,127 @@ def test_init_params():
|
|||
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
|
||||
executor._init_params()
|
||||
assert executor.params == [("key1", "value1"), ("key2", "value2")]
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_bearer():
|
||||
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "bearer", "api_key": ""},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_basic():
|
||||
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "basic", "api_key": ""},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_api_key_raises_error_custom():
|
||||
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
def test_whitespace_only_api_key_raises_error():
|
||||
"""Test that whitespace-only API key raises AuthorizationConfigError."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "bearer", "api_key": " "},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
|
||||
def test_valid_api_key_works():
|
||||
"""Test that valid API key works correctly for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
url="http://example.com",
|
||||
headers="",
|
||||
params="",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={"type": "bearer", "api_key": "valid-api-key-123"},
|
||||
),
|
||||
)
|
||||
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Should not raise an error
|
||||
headers = executor._assembling_headers()
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == "Bearer valid-api-key-123"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
|
@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables):
|
|||
|
||||
|
||||
def test_json_object_valid_schema():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
|
|
@ -65,7 +68,7 @@ def test_json_object_valid_schema():
|
|||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
|
||||
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
result = node._run()
|
||||
|
|
@ -74,12 +77,23 @@ def test_json_object_valid_schema():
|
|||
|
||||
|
||||
def test_json_object_invalid_json_string():
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
json_schema=schema,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -88,38 +102,21 @@ def test_json_object_invalid_json_string():
|
|||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
node._run()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
|
||||
def test_json_object_valid_json_but_not_object(value):
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": value}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
|
||||
node._run()
|
||||
|
||||
|
||||
def test_json_object_does_not_match_schema():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
|
|
@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema():
|
|||
]
|
||||
|
||||
# age is a string, which violates the schema (expects number)
|
||||
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
|
||||
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
|
|
@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema():
|
|||
|
||||
|
||||
def test_json_object_missing_required_schema_field():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
|
|
@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field():
|
|||
]
|
||||
|
||||
# Missing required field "name"
|
||||
user_inputs = {"profile": {"age": 20}}
|
||||
user_inputs = {"profile": json.dumps({"age": 20})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
|
|
@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided():
|
|||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=False,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided():
|
|||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
# Current implementation raises a validation error even when the variable is optional
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
with pytest.raises(ValueError, match="profile is required in input form"):
|
||||
node._run()
|
||||
|
|
|
|||
|
|
@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases:
|
|||
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestBillingServiceSubscriptionOperations:
|
||||
"""Unit tests for subscription operations in BillingService.
|
||||
|
||||
Tests cover:
|
||||
- Bulk plan retrieval with chunking
|
||||
- Expired subscription cleanup whitelist retrieval
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send_request(self):
|
||||
"""Mock _send_request method."""
|
||||
with patch.object(BillingService, "_send_request") as mock:
|
||||
yield mock
|
||||
|
||||
def test_get_plan_bulk_with_empty_list(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with empty tenant list."""
|
||||
# Arrange
|
||||
tenant_ids = []
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
mock_send_request.assert_not_called()
|
||||
|
||||
def test_get_plan_bulk_with_chunking(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with more than 200 tenants (chunking logic)."""
|
||||
# Arrange - 250 tenants to test chunking (chunk_size = 200)
|
||||
tenant_ids = [f"tenant-{i}" for i in range(250)]
|
||||
|
||||
# First chunk: tenants 0-199
|
||||
first_chunk_response = {
|
||||
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
|
||||
}
|
||||
|
||||
# Second chunk: tenants 200-249
|
||||
second_chunk_response = {
|
||||
"data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)}
|
||||
}
|
||||
|
||||
mock_send_request.side_effect = [first_chunk_response, second_chunk_response]
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 250
|
||||
assert result["tenant-0"]["plan"] == "sandbox"
|
||||
assert result["tenant-199"]["plan"] == "sandbox"
|
||||
assert result["tenant-200"]["plan"] == "professional"
|
||||
assert result["tenant-249"]["plan"] == "professional"
|
||||
assert mock_send_request.call_count == 2
|
||||
|
||||
# Verify first chunk call
|
||||
first_call = mock_send_request.call_args_list[0]
|
||||
assert first_call[0][0] == "POST"
|
||||
assert first_call[0][1] == "/subscription/plan/batch"
|
||||
assert len(first_call[1]["json"]["tenant_ids"]) == 200
|
||||
|
||||
# Verify second chunk call
|
||||
second_call = mock_send_request.call_args_list[1]
|
||||
assert len(second_call[1]["json"]["tenant_ids"]) == 50
|
||||
|
||||
def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request):
|
||||
"""Test bulk plan retrieval when one batch fails but others succeed."""
|
||||
# Arrange - 250 tenants, second batch will fail
|
||||
tenant_ids = [f"tenant-{i}" for i in range(250)]
|
||||
|
||||
# First chunk succeeds
|
||||
first_chunk_response = {
|
||||
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
|
||||
}
|
||||
|
||||
# Second chunk fails - need to create a mock that raises when called
|
||||
def side_effect_func(*args, **kwargs):
|
||||
if mock_send_request.call_count == 1:
|
||||
return first_chunk_response
|
||||
else:
|
||||
raise ValueError("API error")
|
||||
|
||||
mock_send_request.side_effect = side_effect_func
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert - should only have data from first batch
|
||||
assert len(result) == 200
|
||||
assert result["tenant-0"]["plan"] == "sandbox"
|
||||
assert result["tenant-199"]["plan"] == "sandbox"
|
||||
assert "tenant-200" not in result
|
||||
assert mock_send_request.call_count == 2
|
||||
|
||||
def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request):
|
||||
"""Test bulk plan retrieval when all batches fail."""
|
||||
# Arrange
|
||||
tenant_ids = [f"tenant-{i}" for i in range(250)]
|
||||
|
||||
# All chunks fail
|
||||
def side_effect_func(*args, **kwargs):
|
||||
raise ValueError("API error")
|
||||
|
||||
mock_send_request.side_effect = side_effect_func
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert - should return empty dict
|
||||
assert result == {}
|
||||
assert mock_send_request.call_count == 2
|
||||
|
||||
def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with exactly 200 tenants (boundary condition)."""
|
||||
# Arrange
|
||||
tenant_ids = [f"tenant-{i}" for i in range(200)]
|
||||
mock_send_request.return_value = {
|
||||
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
|
||||
}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 200
|
||||
assert mock_send_request.call_count == 1
|
||||
|
||||
def test_get_plan_bulk_with_empty_data_response(self, mock_send_request):
|
||||
"""Test bulk plan retrieval with empty data in response."""
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
mock_send_request.return_value = {"data": {}}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_plan_bulk(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request):
|
||||
"""Test successful retrieval of expired subscription cleanup whitelist."""
|
||||
# Arrange
|
||||
api_response = [
|
||||
{
|
||||
"created_at": "2025-10-16T01:56:17",
|
||||
"tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6",
|
||||
"contact": "example@dify.ai",
|
||||
"id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5",
|
||||
"expired_at": "2026-01-01T01:56:17",
|
||||
"updated_at": "2025-10-16T01:56:17",
|
||||
},
|
||||
{
|
||||
"created_at": "2025-10-16T02:00:00",
|
||||
"tenant_id": "tenant-2",
|
||||
"contact": "test@example.com",
|
||||
"id": "whitelist-id-2",
|
||||
"expired_at": "2026-02-01T00:00:00",
|
||||
"updated_at": "2025-10-16T02:00:00",
|
||||
},
|
||||
{
|
||||
"created_at": "2025-10-16T03:00:00",
|
||||
"tenant_id": "tenant-3",
|
||||
"contact": "another@example.com",
|
||||
"id": "whitelist-id-3",
|
||||
"expired_at": "2026-03-01T00:00:00",
|
||||
"updated_at": "2025-10-16T03:00:00",
|
||||
},
|
||||
]
|
||||
mock_send_request.return_value = {"data": api_response}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_expired_subscription_cleanup_whitelist()
|
||||
|
||||
# Assert - should return only tenant_ids
|
||||
assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"]
|
||||
assert len(result) == 3
|
||||
assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6"
|
||||
assert result[1] == "tenant-2"
|
||||
assert result[2] == "tenant-3"
|
||||
mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist")
|
||||
|
||||
def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request):
|
||||
"""Test retrieval of empty cleanup whitelist."""
|
||||
# Arrange
|
||||
mock_send_request.return_value = {"data": []}
|
||||
|
||||
# Act
|
||||
result = BillingService.get_expired_subscription_cleanup_whitelist()
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestBillingServiceIntegrationScenarios:
|
||||
"""Integration-style tests simulating real-world usage scenarios.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,520 @@
|
|||
"""
|
||||
Unit tests for document indexing sync task.
|
||||
|
||||
This module tests the document indexing sync task functionality including:
|
||||
- Syncing Notion documents when updated
|
||||
- Validating document and data source existence
|
||||
- Credential validation and retrieval
|
||||
- Cleaning old segments before re-indexing
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_id():
|
||||
"""Generate a unique tenant ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
"""Generate a unique dataset ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_id():
|
||||
"""Generate a unique document ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notion_workspace_id():
|
||||
"""Generate a Notion workspace ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notion_page_id():
|
||||
"""Generate a Notion page ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def credential_id():
|
||||
"""Generate a credential ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset(dataset_id, tenant_id):
|
||||
"""Create a mock Dataset object."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.embedding_model_provider = "openai"
|
||||
dataset.embedding_model = "text-embedding-ada-002"
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
|
||||
"""Create a mock Document object with Notion data source."""
|
||||
doc = Mock(spec=Document)
|
||||
doc.id = document_id
|
||||
doc.dataset_id = dataset_id
|
||||
doc.tenant_id = tenant_id
|
||||
doc.data_source_type = "notion_import"
|
||||
doc.indexing_status = "completed"
|
||||
doc.error = None
|
||||
doc.stopped_at = None
|
||||
doc.processing_started_at = None
|
||||
doc.doc_form = "text_model"
|
||||
doc.data_source_info_dict = {
|
||||
"notion_workspace_id": notion_workspace_id,
|
||||
"notion_page_id": notion_page_id,
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
"credential_id": credential_id,
|
||||
}
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_segments(document_id):
|
||||
"""Create mock DocumentSegment objects."""
|
||||
segments = []
|
||||
for i in range(3):
|
||||
segment = Mock(spec=DocumentSegment)
|
||||
segment.id = str(uuid.uuid4())
|
||||
segment.document_id = document_id
|
||||
segment.index_node_id = f"node-{document_id}-{i}"
|
||||
segments.append(segment)
|
||||
return segments
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_session.scalars.return_value = MagicMock()
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_datasource_provider_service():
|
||||
"""Mock DatasourceProviderService."""
|
||||
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
|
||||
mock_service_class.return_value = mock_service
|
||||
yield mock_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_notion_extractor():
|
||||
"""Mock NotionExtractor."""
|
||||
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||
mock_extractor = MagicMock()
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
yield mock_extractor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_index_processor_factory():
|
||||
"""Mock IndexProcessorFactory."""
|
||||
with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.clean = Mock()
|
||||
mock_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
yield mock_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_indexing_runner():
|
||||
"""Mock IndexingRunner."""
|
||||
with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
|
||||
mock_runner = MagicMock(spec=IndexingRunner)
|
||||
mock_runner.run = Mock()
|
||||
mock_runner_class.return_value = mock_runner
|
||||
yield mock_runner
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for document_indexing_sync_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDocumentIndexingSyncTask:
|
||||
"""Tests for the document_indexing_sync_task function."""
|
||||
|
||||
def test_document_not_found(self, mock_db_session, dataset_id, document_id):
|
||||
"""Test that task handles document not found gracefully."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when notion_workspace_id is missing."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when notion_page_id is missing."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when data_source_info is empty."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = None
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_credential_not_found(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles missing credentials by updating document status."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_datasource_provider_service.get_datasource_credentials.return_value = None
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
assert mock_document.indexing_status == "error"
|
||||
assert "Datasource credential not found" in mock_document.error
|
||||
assert mock_document.stopped_at is not None
|
||||
mock_db_session.commit.assert_called()
|
||||
mock_db_session.close.assert_called()
|
||||
|
||||
def test_page_not_updated(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task does nothing when page has not been updated."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
# Return same time as stored in document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Document status should remain unchanged
|
||||
assert mock_document.indexing_status == "completed"
|
||||
# No session operations should be performed beyond the initial query
|
||||
mock_db_session.close.assert_not_called()
|
||||
|
||||
def test_successful_sync_when_page_updated(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test successful sync flow when Notion page has been updated."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
# NotionExtractor returns updated time
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Verify document status was updated to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
assert mock_document.processing_started_at is not None
|
||||
|
||||
# Verify segments were cleaned
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# Verify segments were deleted from database
|
||||
for segment in mock_document_segments:
|
||||
mock_db_session.delete.assert_any_call(segment)
|
||||
|
||||
# Verify indexing runner was called
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
|
||||
# Verify session operations
|
||||
assert mock_db_session.commit.called
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_dataset_not_found_during_cleaning(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles dataset not found during cleaning phase."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Document should still be set to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
# Session should be closed after error
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_cleaning_error_continues_to_indexing(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that indexing continues even if cleaning fails."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Indexing should still be attempted despite cleaning error
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_indexing_runner_document_paused_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that DocumentIsPausedError is handled gracefully."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Session should be closed after handling error
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_indexing_runner_general_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that general exceptions during indexing are handled."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Session should be closed after error
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_notion_extractor_initialized_with_correct_params(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
notion_workspace_id,
|
||||
notion_page_id,
|
||||
):
|
||||
"""Test that NotionExtractor is initialized with correct parameters."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update
|
||||
|
||||
# Act
|
||||
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||
mock_extractor = MagicMock()
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_extractor_class.assert_called_once_with(
|
||||
notion_workspace_id=notion_workspace_id,
|
||||
notion_obj_id=notion_page_id,
|
||||
notion_page_type="page",
|
||||
notion_access_token="test_token",
|
||||
tenant_id=mock_document.tenant_id,
|
||||
)
|
||||
|
||||
def test_datasource_credentials_requested_correctly(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
credential_id,
|
||||
):
|
||||
"""Test that datasource credentials are requested with correct parameters."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
|
||||
tenant_id=mock_document.tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
def test_credential_id_missing_uses_none(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles missing credential_id by passing None."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {
|
||||
"notion_workspace_id": "ws123",
|
||||
"notion_page_id": "page123",
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
|
||||
tenant_id=mock_document.tenant_id,
|
||||
credential_id=None,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
def test_index_processor_clean_called_with_correct_params(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that index processor clean is called with correct parameters."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
|
||||
mock_processor.clean.assert_called_once_with(
|
||||
mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
|
||||
)
|
||||
|
|
@ -37,6 +37,7 @@ show_help() {
|
|||
echo " pipeline - Standard pipeline tasks"
|
||||
echo " triggered_workflow_dispatcher - Trigger dispatcher tasks"
|
||||
echo " trigger_refresh_executor - Trigger refresh tasks"
|
||||
echo " retention - Retention tasks"
|
||||
}
|
||||
|
||||
# Parse command line arguments
|
||||
|
|
@ -105,10 +106,10 @@ if [[ -z "${QUEUES}" ]]; then
|
|||
# Configure queues based on edition
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset and workflow have separate queues
|
||||
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
fi
|
||||
|
||||
echo "No queues specified, using edition-based defaults: ${QUEUES}"
|
||||
|
|
|
|||
|
|
@ -1369,7 +1369,10 @@ PLUGIN_STDIO_BUFFER_SIZE=1024
|
|||
PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880
|
||||
|
||||
PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120
|
||||
# Plugin Daemon side timeout (configure to match the API side below)
|
||||
PLUGIN_MAX_EXECUTION_TIMEOUT=600
|
||||
# API side timeout (configure to match the Plugin Daemon side above)
|
||||
PLUGIN_DAEMON_TIMEOUT=600.0
|
||||
# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
PIP_MIRROR_URL=
|
||||
|
||||
|
|
@ -1479,4 +1482,9 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
|
|||
ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||
|
||||
# The API key of amplitude
|
||||
AMPLITUDE_API_KEY=
|
||||
AMPLITUDE_API_KEY=
|
||||
|
||||
# Sandbox expired records clean configuration
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ services:
|
|||
PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
|
||||
PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
|
||||
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
|
||||
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
|
||||
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
|
||||
depends_on:
|
||||
init_permissions:
|
||||
|
|
|
|||
|
|
@ -591,6 +591,7 @@ x-shared-env: &shared-api-worker-env
|
|||
PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880}
|
||||
PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
|
||||
PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600}
|
||||
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
|
||||
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
|
||||
PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local}
|
||||
PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage}
|
||||
|
|
@ -663,6 +664,9 @@ x-shared-env: &shared-api-worker-env
|
|||
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20}
|
||||
ANNOTATION_IMPORT_MAX_CONCURRENT: ${ANNOTATION_IMPORT_MAX_CONCURRENT:-5}
|
||||
AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21}
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000}
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30}
|
||||
|
||||
services:
|
||||
# Init container to fix permissions
|
||||
|
|
@ -699,6 +703,7 @@ services:
|
|||
PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
|
||||
PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
|
||||
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
|
||||
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
|
||||
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
|
||||
depends_on:
|
||||
init_permissions:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Mock for ky HTTP client
|
||||
* This mock is used to avoid ESM issues in Jest tests
|
||||
*/
|
||||
|
||||
type KyResponse = {
|
||||
ok: boolean
|
||||
status: number
|
||||
statusText: string
|
||||
headers: Headers
|
||||
json: jest.Mock
|
||||
text: jest.Mock
|
||||
blob: jest.Mock
|
||||
arrayBuffer: jest.Mock
|
||||
clone: jest.Mock
|
||||
}
|
||||
|
||||
type KyInstance = jest.Mock & {
|
||||
get: jest.Mock
|
||||
post: jest.Mock
|
||||
put: jest.Mock
|
||||
patch: jest.Mock
|
||||
delete: jest.Mock
|
||||
head: jest.Mock
|
||||
create: jest.Mock
|
||||
extend: jest.Mock
|
||||
stop: symbol
|
||||
}
|
||||
|
||||
const createResponse = (data: unknown = {}, status = 200): KyResponse => {
|
||||
const response: KyResponse = {
|
||||
ok: status >= 200 && status < 300,
|
||||
status,
|
||||
statusText: status === 200 ? 'OK' : 'Error',
|
||||
headers: new Headers(),
|
||||
json: jest.fn().mockResolvedValue(data),
|
||||
text: jest.fn().mockResolvedValue(JSON.stringify(data)),
|
||||
blob: jest.fn().mockResolvedValue(new Blob()),
|
||||
arrayBuffer: jest.fn().mockResolvedValue(new ArrayBuffer(0)),
|
||||
clone: jest.fn(),
|
||||
}
|
||||
// Ensure clone returns a new response-like object, not the same instance
|
||||
response.clone.mockImplementation(() => createResponse(data, status))
|
||||
return response
|
||||
}
|
||||
|
||||
const createKyInstance = (): KyInstance => {
|
||||
const instance = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) as KyInstance
|
||||
|
||||
// HTTP methods
|
||||
instance.get = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
|
||||
instance.post = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
|
||||
instance.put = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
|
||||
instance.patch = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
|
||||
instance.delete = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
|
||||
instance.head = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
|
||||
|
||||
// Create new instance with custom options
|
||||
instance.create = jest.fn().mockImplementation(() => createKyInstance())
|
||||
instance.extend = jest.fn().mockImplementation(() => createKyInstance())
|
||||
|
||||
// Stop method for AbortController
|
||||
instance.stop = Symbol('stop')
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
const ky = createKyInstance()
|
||||
|
||||
export default ky
|
||||
export { ky }
|
||||
|
|
@ -16,7 +16,7 @@ import {
|
|||
import { useTranslation } from 'react-i18next'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import s from './style.module.css'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useStore } from '@/app/components/app/store'
|
||||
import AppSideBar from '@/app/components/app-sidebar'
|
||||
import type { NavIcon } from '@/app/components/app-sidebar/navLink'
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react'
|
|||
import type { Dayjs } from 'dayjs'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { formatToLocalTime } from '@/utils/format'
|
||||
import { useI18N } from '@/context/i18n'
|
||||
import Picker from '@/app/components/base/date-and-time-picker/date-picker'
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select'
|
|||
import type { Item } from '@/app/components/base/select'
|
||||
import dayjs from 'dayjs'
|
||||
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
const today = dayjs()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react'
|
|||
|
||||
import type { PopupProps } from './config-popup'
|
||||
import ConfigPopup from './config-popup'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator'
|
|||
import Switch from '@/app/components/base/switch'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
const I18N_PREFIX = 'app.tracing'
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Input from '@/app/components/base/input'
|
||||
|
||||
type Props = {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS
|
|||
import { TracingProvider } from './type'
|
||||
import TracingIcon from './tracing-icon'
|
||||
import ConfigButton from './config-button'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {
|
|||
} from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { TracingProvider } from './type'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
|
||||
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing'
|
||||
|
||||
type Props = {
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use
|
|||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import ExtraInfo from '@/app/components/datasets/extra-info'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
export type IAppDetailLayoutProps = {
|
||||
children: React.ReactNode
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
'use client'
|
||||
import Header from '@/app/signin/_header'
|
||||
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
|
||||
export default function SignInLayout({ children }: any) {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import cn from 'classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { RiCheckboxCircleFill } from '@remixicon/react'
|
||||
import { useCountDown } from 'ahooks'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
'use client'
|
||||
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import type { PropsWithChildren } from 'react'
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading'
|
|||
import MailAndCodeAuth from './components/mail-and-code-auth'
|
||||
import MailAndPasswordAuth from './components/mail-and-password-auth'
|
||||
import SSOAuth from './components/sso-auth'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { LicenseStatus } from '@/types/feature'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
'use client'
|
||||
import Header from '@/app/signin/_header'
|
||||
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import { AppContextProvider } from '@/context/app-context'
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
'use client'
|
||||
import { useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import useSWR from 'swr'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
||||
import { invitationCheck } from '@/service/common'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import { useInvitationCheck } from '@/service/use-common'
|
||||
|
||||
const ActivateForm = () => {
|
||||
useDocumentTitle('')
|
||||
|
|
@ -26,19 +26,21 @@ const ActivateForm = () => {
|
|||
token,
|
||||
},
|
||||
}
|
||||
const { data: checkRes } = useSWR(checkParams, invitationCheck, {
|
||||
revalidateOnFocus: false,
|
||||
onSuccess(data) {
|
||||
if (data.is_valid) {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
const { email, workspace_id } = data.data
|
||||
params.set('email', encodeURIComponent(email))
|
||||
params.set('workspace_id', encodeURIComponent(workspace_id))
|
||||
params.set('invite_token', encodeURIComponent(token as string))
|
||||
router.replace(`/signin?${params.toString()}`)
|
||||
}
|
||||
},
|
||||
})
|
||||
const { data: checkRes } = useInvitationCheck({
|
||||
...checkParams.params,
|
||||
token: token || undefined,
|
||||
}, true)
|
||||
|
||||
useEffect(() => {
|
||||
if (checkRes?.is_valid) {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
const { email, workspace_id } = checkRes.data
|
||||
params.set('email', encodeURIComponent(email))
|
||||
params.set('workspace_id', encodeURIComponent(workspace_id))
|
||||
params.set('invite_token', encodeURIComponent(token as string))
|
||||
router.replace(`/signin?${params.toString()}`)
|
||||
}
|
||||
}, [checkRes, router, searchParams, token])
|
||||
|
||||
return (
|
||||
<div className={
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import React from 'react'
|
||||
import Header from '../signin/_header'
|
||||
import ActivateForm from './activateForm'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
|
||||
const Activate = () => {
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie
|
|||
import type { Operation } from './app-operations'
|
||||
import AppOperations from './app-operations'
|
||||
import dynamic from 'next/dynamic'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import AppInfo from './app-info'
|
|||
import NavLink from './navLink'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import type { NavIcon } from './navLink'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
|
||||
type Props = {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react'
|
|||
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem'
|
||||
import ActionButton from '../../base/action-button'
|
||||
import { RiMoreFill } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Menu from './menu'
|
||||
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
|
||||
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
||||
|
|
|
|||
|
|
@ -0,0 +1,379 @@
|
|||
import React from 'react'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import DatasetInfo from './index'
|
||||
import Dropdown from './dropdown'
|
||||
import Menu from './menu'
|
||||
import MenuItem from './menu-item'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import {
|
||||
ChunkingMode,
|
||||
DataSourceType,
|
||||
DatasetPermission,
|
||||
} from '@/models/datasets'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import { RiEditLine } from '@remixicon/react'
|
||||
|
||||
let mockDataset: DataSet
|
||||
let mockIsDatasetOperator = false
|
||||
const mockReplace = jest.fn()
|
||||
const mockInvalidDatasetList = jest.fn()
|
||||
const mockInvalidDatasetDetail = jest.fn()
|
||||
const mockExportPipeline = jest.fn()
|
||||
const mockCheckIsUsedInApp = jest.fn()
|
||||
const mockDeleteDataset = jest.fn()
|
||||
|
||||
const createDataset = (overrides: Partial<DataSet> = {}): DataSet => ({
|
||||
id: 'dataset-1',
|
||||
name: 'Dataset Name',
|
||||
indexing_status: 'completed',
|
||||
icon_info: {
|
||||
icon: '📙',
|
||||
icon_background: '#FFF4ED',
|
||||
icon_type: 'emoji',
|
||||
icon_url: '',
|
||||
},
|
||||
description: 'Dataset description',
|
||||
permission: DatasetPermission.onlyMe,
|
||||
data_source_type: DataSourceType.FILE,
|
||||
indexing_technique: 'high_quality' as DataSet['indexing_technique'],
|
||||
created_by: 'user-1',
|
||||
updated_by: 'user-1',
|
||||
updated_at: 1690000000,
|
||||
app_count: 0,
|
||||
doc_form: ChunkingMode.text,
|
||||
document_count: 1,
|
||||
total_document_count: 1,
|
||||
word_count: 1000,
|
||||
provider: 'internal',
|
||||
embedding_model: 'text-embedding-3',
|
||||
embedding_model_provider: 'openai',
|
||||
embedding_available: true,
|
||||
retrieval_model_dict: {
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
top_k: 5,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0,
|
||||
},
|
||||
retrieval_model: {
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
top_k: 5,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0,
|
||||
},
|
||||
tags: [],
|
||||
external_knowledge_info: {
|
||||
external_knowledge_id: '',
|
||||
external_knowledge_api_id: '',
|
||||
external_knowledge_api_name: '',
|
||||
external_knowledge_api_endpoint: '',
|
||||
},
|
||||
external_retrieval_model: {
|
||||
top_k: 0,
|
||||
score_threshold: 0,
|
||||
score_threshold_enabled: false,
|
||||
},
|
||||
built_in_field_enabled: false,
|
||||
runtime_mode: 'rag_pipeline',
|
||||
enable_api: false,
|
||||
is_multimodal: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
jest.mock('next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
replace: mockReplace,
|
||||
}),
|
||||
}))
|
||||
|
||||
jest.mock('@/context/dataset-detail', () => ({
|
||||
useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }),
|
||||
}))
|
||||
|
||||
jest.mock('@/context/app-context', () => ({
|
||||
useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) =>
|
||||
selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }),
|
||||
}))
|
||||
|
||||
jest.mock('@/service/knowledge/use-dataset', () => ({
|
||||
datasetDetailQueryKeyPrefix: ['dataset', 'detail'],
|
||||
useInvalidDatasetList: () => mockInvalidDatasetList,
|
||||
}))
|
||||
|
||||
jest.mock('@/service/use-base', () => ({
|
||||
useInvalid: () => mockInvalidDatasetDetail,
|
||||
}))
|
||||
|
||||
jest.mock('@/service/use-pipeline', () => ({
|
||||
useExportPipelineDSL: () => ({
|
||||
mutateAsync: mockExportPipeline,
|
||||
}),
|
||||
}))
|
||||
|
||||
jest.mock('@/service/datasets', () => ({
|
||||
checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args),
|
||||
deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args),
|
||||
}))
|
||||
|
||||
jest.mock('@/hooks/use-knowledge', () => ({
|
||||
useKnowledge: () => ({
|
||||
formatIndexingTechniqueAndMethod: () => 'indexing-technique',
|
||||
}),
|
||||
}))
|
||||
|
||||
jest.mock('@/app/components/datasets/rename-modal', () => ({
|
||||
__esModule: true,
|
||||
default: ({
|
||||
show,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}: {
|
||||
show: boolean
|
||||
onClose: () => void
|
||||
onSuccess?: () => void
|
||||
}) => {
|
||||
if (!show)
|
||||
return null
|
||||
return (
|
||||
<div data-testid="rename-modal">
|
||||
<button type="button" onClick={onSuccess}>Success</button>
|
||||
<button type="button" onClick={onClose}>Close</button>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
const openMenu = async (user: ReturnType<typeof userEvent.setup>) => {
|
||||
const trigger = screen.getByRole('button')
|
||||
await user.click(trigger)
|
||||
}
|
||||
|
||||
describe('DatasetInfo', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
mockDataset = createDataset()
|
||||
mockIsDatasetOperator = false
|
||||
})
|
||||
|
||||
// Rendering of dataset summary details based on expand and dataset state.
|
||||
describe('Rendering', () => {
|
||||
it('should show dataset details when expanded', () => {
|
||||
// Arrange
|
||||
mockDataset = createDataset({ is_published: true })
|
||||
render(<DatasetInfo expand />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('Dataset Name')).toBeInTheDocument()
|
||||
expect(screen.getByText('Dataset description')).toBeInTheDocument()
|
||||
expect(screen.getByText('dataset.chunkingMode.general')).toBeInTheDocument()
|
||||
expect(screen.getByText('indexing-technique')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show external tag when provider is external', () => {
|
||||
// Arrange
|
||||
mockDataset = createDataset({ provider: 'external', is_published: false })
|
||||
render(<DatasetInfo expand />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('dataset.externalTag')).toBeInTheDocument()
|
||||
expect(screen.queryByText('dataset.chunkingMode.general')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide detailed fields when collapsed', () => {
|
||||
// Arrange
|
||||
render(<DatasetInfo expand={false} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByText('Dataset Name')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('Dataset description')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('MenuItem', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
// Event handling for menu item interactions.
|
||||
describe('Interactions', () => {
|
||||
it('should call handler when clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const handleClick = jest.fn()
|
||||
// Arrange
|
||||
render(<MenuItem name="Edit" Icon={RiEditLine} handleClick={handleClick} />)
|
||||
|
||||
// Act
|
||||
await user.click(screen.getByText('Edit'))
|
||||
|
||||
// Assert
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Menu', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
mockDataset = createDataset()
|
||||
})
|
||||
|
||||
// Rendering of menu options based on runtime mode and delete visibility.
|
||||
describe('Rendering', () => {
|
||||
it('should show edit, export, and delete options when rag pipeline and deletable', () => {
|
||||
// Arrange
|
||||
mockDataset = createDataset({ runtime_mode: 'rag_pipeline' })
|
||||
render(
|
||||
<Menu
|
||||
showDelete
|
||||
openRenameModal={jest.fn()}
|
||||
handleExportPipeline={jest.fn()}
|
||||
detectIsUsedByApp={jest.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
|
||||
expect(screen.getByText('datasetPipeline.operations.exportPipeline')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.operation.delete')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide export and delete options when not rag pipeline and not deletable', () => {
|
||||
// Arrange
|
||||
mockDataset = createDataset({ runtime_mode: 'general' })
|
||||
render(
|
||||
<Menu
|
||||
showDelete={false}
|
||||
openRenameModal={jest.fn()}
|
||||
handleExportPipeline={jest.fn()}
|
||||
detectIsUsedByApp={jest.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
|
||||
expect(screen.queryByText('datasetPipeline.operations.exportPipeline')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Dropdown', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' })
|
||||
mockIsDatasetOperator = false
|
||||
mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' })
|
||||
mockCheckIsUsedInApp.mockResolvedValue({ is_using: false })
|
||||
mockDeleteDataset.mockResolvedValue({})
|
||||
if (!('createObjectURL' in URL)) {
|
||||
Object.defineProperty(URL, 'createObjectURL', {
|
||||
value: jest.fn(),
|
||||
writable: true,
|
||||
})
|
||||
}
|
||||
if (!('revokeObjectURL' in URL)) {
|
||||
Object.defineProperty(URL, 'revokeObjectURL', {
|
||||
value: jest.fn(),
|
||||
writable: true,
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Rendering behavior based on workspace role.
|
||||
describe('Rendering', () => {
|
||||
it('should hide delete option when user is dataset operator', async () => {
|
||||
const user = userEvent.setup()
|
||||
// Arrange
|
||||
mockIsDatasetOperator = true
|
||||
render(<Dropdown expand />)
|
||||
|
||||
// Act
|
||||
await openMenu(user)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// User interactions that trigger modals and exports.
|
||||
describe('Interactions', () => {
|
||||
it('should open rename modal when edit is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
// Arrange
|
||||
render(<Dropdown expand />)
|
||||
|
||||
// Act
|
||||
await openMenu(user)
|
||||
await user.click(screen.getByText('common.operation.edit'))
|
||||
|
||||
// Assert
|
||||
expect(screen.getByTestId('rename-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should export pipeline when export is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const anchorClickSpy = jest.spyOn(HTMLAnchorElement.prototype, 'click')
|
||||
const createObjectURLSpy = jest.spyOn(URL, 'createObjectURL')
|
||||
// Arrange
|
||||
render(<Dropdown expand />)
|
||||
|
||||
// Act
|
||||
await openMenu(user)
|
||||
await user.click(screen.getByText('datasetPipeline.operations.exportPipeline'))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockExportPipeline).toHaveBeenCalledWith({
|
||||
pipelineId: 'pipeline-1',
|
||||
include: false,
|
||||
})
|
||||
})
|
||||
expect(createObjectURLSpy).toHaveBeenCalledTimes(1)
|
||||
expect(anchorClickSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should show delete confirmation when delete is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
// Arrange
|
||||
render(<Dropdown expand />)
|
||||
|
||||
// Act
|
||||
await openMenu(user)
|
||||
await user.click(screen.getByText('common.operation.delete'))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('dataset.deleteDatasetConfirmContent')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should delete dataset and redirect when confirm is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
// Arrange
|
||||
render(<Dropdown expand />)
|
||||
|
||||
// Act
|
||||
await openMenu(user)
|
||||
await user.click(screen.getByText('common.operation.delete'))
|
||||
await user.click(await screen.findByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1')
|
||||
})
|
||||
expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1)
|
||||
expect(mockReplace).toHaveBeenCalledWith('/datasets')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
|||
import type { DataSet } from '@/models/datasets'
|
||||
import { DOC_FORM_TEXT } from '@/models/datasets'
|
||||
import { useKnowledge } from '@/hooks/use-knowledge'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Dropdown from './dropdown'
|
||||
|
||||
type DatasetInfoProps = {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon'
|
|||
import Divider from '../base/divider'
|
||||
import NavLink from './navLink'
|
||||
import type { NavIcon } from './navLink'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
||||
import Effect from '../base/effect'
|
||||
import Dropdown from './dataset-info/dropdown'
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import AppSidebarDropdown from './app-sidebar-dropdown'
|
|||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Divider from '../base/divider'
|
||||
import { useHover, useKeyPress } from 'ahooks'
|
||||
import ToggleButton from './toggle-button'
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import React from 'react'
|
||||
import { useSelectedLayoutSegment } from 'next/navigation'
|
||||
import Link from 'next/link'
|
||||
import classNames from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import type { RemixiconComponentType } from '@remixicon/react'
|
||||
|
||||
export type NavIcon = React.ComponentType<
|
||||
|
|
@ -42,7 +42,7 @@ const NavLink = ({
|
|||
const NavIcon = isActive ? iconMap.selected : iconMap.normal
|
||||
|
||||
const renderIcon = () => (
|
||||
<div className={classNames(mode !== 'expand' && '-ml-1')}>
|
||||
<div className={cn(mode !== 'expand' && '-ml-1')}>
|
||||
<NavIcon className="h-4 w-4 shrink-0" aria-hidden="true" />
|
||||
</div>
|
||||
)
|
||||
|
|
@ -53,21 +53,17 @@ const NavLink = ({
|
|||
key={name}
|
||||
type='button'
|
||||
disabled
|
||||
className={classNames(
|
||||
'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
|
||||
'pl-3 pr-1',
|
||||
)}
|
||||
className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
|
||||
'pl-3 pr-1')}
|
||||
title={mode === 'collapse' ? name : ''}
|
||||
aria-disabled
|
||||
>
|
||||
{renderIcon()}
|
||||
<span
|
||||
className={classNames(
|
||||
'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
|
||||
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
|
||||
mode === 'expand'
|
||||
? 'ml-2 max-w-none opacity-100'
|
||||
: 'ml-0 max-w-0 opacity-0',
|
||||
)}
|
||||
: 'ml-0 max-w-0 opacity-0')}
|
||||
>
|
||||
{name}
|
||||
</span>
|
||||
|
|
@ -79,22 +75,18 @@ const NavLink = ({
|
|||
<Link
|
||||
key={name}
|
||||
href={href}
|
||||
className={classNames(
|
||||
isActive
|
||||
? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
|
||||
: 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
|
||||
'flex h-8 items-center rounded-lg pl-3 pr-1',
|
||||
)}
|
||||
className={cn(isActive
|
||||
? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
|
||||
: 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
|
||||
'flex h-8 items-center rounded-lg pl-3 pr-1')}
|
||||
title={mode === 'collapse' ? name : ''}
|
||||
>
|
||||
{renderIcon()}
|
||||
<span
|
||||
className={classNames(
|
||||
'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
|
||||
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
|
||||
mode === 'expand'
|
||||
? 'ml-2 max-w-none opacity-100'
|
||||
: 'ml-0 max-w-0 opacity-0',
|
||||
)}
|
||||
: 'ml-0 max-w-0 opacity-0')}
|
||||
>
|
||||
{name}
|
||||
</span>
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import React from 'react'
|
||||
import Button from '../base/button'
|
||||
import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Tooltip from '../base/tooltip'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { getKeyboardKeyNameBySystem } from '../workflow/utils'
|
||||
|
|
|
|||
|
|
@ -2,12 +2,6 @@ import React from 'react'
|
|||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import EditItem, { EditItemType } from './index'
|
||||
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('AddAnnotationModal/EditItem', () => {
|
||||
test('should render query inputs with user avatar and placeholder strings', () => {
|
||||
render(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
import React from 'react'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import BatchAction from './batch-action'
|
||||
|
||||
describe('BatchAction', () => {
|
||||
const baseProps = {
|
||||
selectedIds: ['1', '2', '3'],
|
||||
onBatchDelete: jest.fn(),
|
||||
onCancel: jest.fn(),
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should show the selected count and trigger cancel action', () => {
|
||||
render(<BatchAction {...baseProps} className='custom-class' />)
|
||||
|
||||
expect(screen.getByText('3')).toBeInTheDocument()
|
||||
expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
|
||||
|
||||
expect(baseProps.onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should confirm before running batch delete', async () => {
|
||||
const onBatchDelete = jest.fn().mockResolvedValue(undefined)
|
||||
render(<BatchAction {...baseProps} onBatchDelete={onBatchDelete} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.delete' }))
|
||||
await screen.findByText('appAnnotation.list.delete.title')
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.delete' })[1])
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onBatchDelete).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -3,7 +3,7 @@ import { RiDeleteBinLine } from '@remixicon/react'
|
|||
import { useTranslation } from 'react-i18next'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import classNames from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
|
||||
const i18nPrefix = 'appAnnotation.batchAction'
|
||||
|
|
@ -38,7 +38,7 @@ const BatchAction: FC<IBatchActionProps> = ({
|
|||
setIsNotDeleting()
|
||||
}
|
||||
return (
|
||||
<div className={classNames('pointer-events-none flex w-full justify-center', className)}>
|
||||
<div className={cn('pointer-events-none flex w-full justify-center', className)}>
|
||||
<div className='pointer-events-auto flex items-center gap-x-1 rounded-[10px] border border-components-actionbar-border-accent bg-components-actionbar-bg-accent p-1 shadow-xl shadow-shadow-shadow-5 backdrop-blur-[5px]'>
|
||||
<div className='inline-flex items-center gap-x-2 py-1 pl-2 pr-3'>
|
||||
<span className='flex h-5 w-5 items-center justify-center rounded-md bg-text-accent px-1 py-0.5 text-xs font-medium text-text-primary-on-surface'>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,72 @@
|
|||
import React from 'react'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import CSVDownload from './csv-downloader'
|
||||
import I18nContext from '@/context/i18n'
|
||||
import { LanguagesSupported } from '@/i18n-config/language'
|
||||
import type { Locale } from '@/i18n-config'
|
||||
|
||||
const downloaderProps: any[] = []
|
||||
|
||||
jest.mock('react-papaparse', () => ({
|
||||
useCSVDownloader: jest.fn(() => ({
|
||||
CSVDownloader: ({ children, ...props }: any) => {
|
||||
downloaderProps.push(props)
|
||||
return <div data-testid="mock-csv-downloader">{children}</div>
|
||||
},
|
||||
Type: { Link: 'link' },
|
||||
})),
|
||||
}))
|
||||
|
||||
const renderWithLocale = (locale: Locale) => {
|
||||
return render(
|
||||
<I18nContext.Provider value={{
|
||||
locale,
|
||||
i18n: {},
|
||||
setLocaleOnClient: jest.fn().mockResolvedValue(undefined),
|
||||
}}
|
||||
>
|
||||
<CSVDownload />
|
||||
</I18nContext.Provider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('CSVDownload', () => {
|
||||
const englishTemplate = [
|
||||
['question', 'answer'],
|
||||
['question1', 'answer1'],
|
||||
['question2', 'answer2'],
|
||||
]
|
||||
const chineseTemplate = [
|
||||
['问题', '答案'],
|
||||
['问题 1', '答案 1'],
|
||||
['问题 2', '答案 2'],
|
||||
]
|
||||
|
||||
beforeEach(() => {
|
||||
downloaderProps.length = 0
|
||||
})
|
||||
|
||||
it('should render the structure preview and pass English template data by default', () => {
|
||||
renderWithLocale('en-US' as Locale)
|
||||
|
||||
expect(screen.getByText('share.generation.csvStructureTitle')).toBeInTheDocument()
|
||||
expect(screen.getByText('appAnnotation.batchModal.template')).toBeInTheDocument()
|
||||
|
||||
expect(downloaderProps[0]).toMatchObject({
|
||||
filename: 'template-en-US',
|
||||
type: 'link',
|
||||
bom: true,
|
||||
data: englishTemplate,
|
||||
})
|
||||
})
|
||||
|
||||
it('should switch to the Chinese template when locale matches the secondary language', () => {
|
||||
const locale = LanguagesSupported[1] as Locale
|
||||
renderWithLocale(locale)
|
||||
|
||||
expect(downloaderProps[0]).toMatchObject({
|
||||
filename: `template-${locale}`,
|
||||
data: chineseTemplate,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -4,7 +4,7 @@ import React, { useEffect, useRef, useState } from 'react'
|
|||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { RiDeleteBinLine } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
|
|
|||
|
|
@ -0,0 +1,164 @@
|
|||
import React from 'react'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import BatchModal, { ProcessStatus } from './index'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation'
|
||||
import type { IBatchModalProps } from './index'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
jest.mock('@/app/components/base/toast', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
notify: jest.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
jest.mock('@/service/annotation', () => ({
|
||||
annotationBatchImport: jest.fn(),
|
||||
checkAnnotationBatchImportProgress: jest.fn(),
|
||||
}))
|
||||
|
||||
jest.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: jest.fn(),
|
||||
}))
|
||||
|
||||
jest.mock('./csv-downloader', () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="csv-downloader-stub" />,
|
||||
}))
|
||||
|
||||
let lastUploadedFile: File | undefined
|
||||
|
||||
jest.mock('./csv-uploader', () => ({
|
||||
__esModule: true,
|
||||
default: ({ file, updateFile }: { file?: File; updateFile: (file?: File) => void }) => (
|
||||
<div>
|
||||
<button
|
||||
data-testid="mock-uploader"
|
||||
onClick={() => {
|
||||
lastUploadedFile = new File(['question,answer'], 'batch.csv', { type: 'text/csv' })
|
||||
updateFile(lastUploadedFile)
|
||||
}}
|
||||
>
|
||||
upload
|
||||
</button>
|
||||
{file && <span data-testid="selected-file">{file.name}</span>}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
jest.mock('@/app/components/billing/annotation-full', () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="annotation-full" />,
|
||||
}))
|
||||
|
||||
const mockNotify = Toast.notify as jest.Mock
|
||||
const useProviderContextMock = useProviderContext as jest.Mock
|
||||
const annotationBatchImportMock = annotationBatchImport as jest.Mock
|
||||
const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as jest.Mock
|
||||
|
||||
const renderComponent = (props: Partial<IBatchModalProps> = {}) => {
|
||||
const mergedProps: IBatchModalProps = {
|
||||
appId: 'app-id',
|
||||
isShow: true,
|
||||
onCancel: jest.fn(),
|
||||
onAdded: jest.fn(),
|
||||
...props,
|
||||
}
|
||||
return {
|
||||
...render(<BatchModal {...mergedProps} />),
|
||||
props: mergedProps,
|
||||
}
|
||||
}
|
||||
|
||||
describe('BatchModal', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
lastUploadedFile = undefined
|
||||
useProviderContextMock.mockReturnValue({
|
||||
plan: {
|
||||
usage: { annotatedResponse: 0 },
|
||||
total: { annotatedResponse: 10 },
|
||||
},
|
||||
enableBilling: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('should disable run action and show billing hint when annotation quota is full', () => {
|
||||
useProviderContextMock.mockReturnValue({
|
||||
plan: {
|
||||
usage: { annotatedResponse: 10 },
|
||||
total: { annotatedResponse: 10 },
|
||||
},
|
||||
enableBilling: true,
|
||||
})
|
||||
|
||||
renderComponent()
|
||||
|
||||
expect(screen.getByTestId('annotation-full')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should reset uploader state when modal closes and allow manual cancellation', () => {
|
||||
const { rerender, props } = renderComponent()
|
||||
|
||||
fireEvent.click(screen.getByTestId('mock-uploader'))
|
||||
expect(screen.getByTestId('selected-file')).toHaveTextContent('batch.csv')
|
||||
|
||||
rerender(<BatchModal {...props} isShow={false} />)
|
||||
rerender(<BatchModal {...props} isShow />)
|
||||
|
||||
expect(screen.queryByTestId('selected-file')).toBeNull()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }))
|
||||
expect(props.onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should submit the csv file, poll status, and notify when import completes', async () => {
|
||||
jest.useFakeTimers()
|
||||
const { props } = renderComponent()
|
||||
const fileTrigger = screen.getByTestId('mock-uploader')
|
||||
fireEvent.click(fileTrigger)
|
||||
|
||||
const runButton = screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })
|
||||
expect(runButton).not.toBeDisabled()
|
||||
|
||||
annotationBatchImportMock.mockResolvedValue({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING })
|
||||
checkAnnotationBatchImportProgressMock
|
||||
.mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING })
|
||||
.mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.COMPLETED })
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(runButton)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(annotationBatchImportMock).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
const formData = annotationBatchImportMock.mock.calls[0][0].body as FormData
|
||||
expect(formData.get('file')).toBe(lastUploadedFile)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
jest.runOnlyPendingTimers()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
message: 'appAnnotation.batchModal.completed',
|
||||
})
|
||||
expect(props.onAdded).toHaveBeenCalledTimes(1)
|
||||
expect(props.onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
jest.useRealTimers()
|
||||
})
|
||||
})
|
||||
|
|
@ -245,7 +245,7 @@ describe('EditItem', () => {
|
|||
expect(mockSave).toHaveBeenCalledWith('Test save content')
|
||||
})
|
||||
|
||||
it('should show delete option when content changes', async () => {
|
||||
it('should show delete option and restore original content when delete is clicked', async () => {
|
||||
// Arrange
|
||||
const mockSave = jest.fn().mockResolvedValue(undefined)
|
||||
const props = {
|
||||
|
|
@ -267,7 +267,13 @@ describe('EditItem', () => {
|
|||
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
|
||||
// Assert
|
||||
expect(mockSave).toHaveBeenCalledWith('Modified content')
|
||||
expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content')
|
||||
expect(await screen.findByText('common.operation.delete')).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByText('common.operation.delete'))
|
||||
|
||||
expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content')
|
||||
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle keyboard interactions in edit mode', async () => {
|
||||
|
|
@ -393,5 +399,68 @@ describe('EditItem', () => {
|
|||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('Test content')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle save failure gracefully in edit mode', async () => {
|
||||
// Arrange
|
||||
const mockSave = jest.fn().mockRejectedValueOnce(new Error('Save failed'))
|
||||
const props = {
|
||||
...defaultProps,
|
||||
onSave: mockSave,
|
||||
}
|
||||
const user = userEvent.setup()
|
||||
|
||||
// Act
|
||||
render(<EditItem {...props} />)
|
||||
|
||||
// Enter edit mode and save (should fail)
|
||||
await user.click(screen.getByText('common.operation.edit'))
|
||||
const textarea = screen.getByRole('textbox')
|
||||
await user.type(textarea, 'New content')
|
||||
|
||||
// Save should fail but not throw
|
||||
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
|
||||
// Assert - Should remain in edit mode when save fails
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
|
||||
expect(mockSave).toHaveBeenCalledWith('New content')
|
||||
})
|
||||
|
||||
it('should handle delete action failure gracefully', async () => {
|
||||
// Arrange
|
||||
const mockSave = jest.fn()
|
||||
.mockResolvedValueOnce(undefined) // First save succeeds
|
||||
.mockRejectedValueOnce(new Error('Delete failed')) // Delete fails
|
||||
const props = {
|
||||
...defaultProps,
|
||||
onSave: mockSave,
|
||||
}
|
||||
const user = userEvent.setup()
|
||||
|
||||
// Act
|
||||
render(<EditItem {...props} />)
|
||||
|
||||
// Edit content to show delete button
|
||||
await user.click(screen.getByText('common.operation.edit'))
|
||||
const textarea = screen.getByRole('textbox')
|
||||
await user.clear(textarea)
|
||||
await user.type(textarea, 'Modified content')
|
||||
|
||||
// Save to create new content
|
||||
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
await screen.findByText('common.operation.delete')
|
||||
|
||||
// Click delete (should fail but not throw)
|
||||
await user.click(screen.getByText('common.operation.delete'))
|
||||
|
||||
// Assert - Delete action should handle error gracefully
|
||||
expect(mockSave).toHaveBeenCalledTimes(2)
|
||||
expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content')
|
||||
expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content')
|
||||
|
||||
// When delete fails, the delete button should still be visible (state not changed)
|
||||
expect(screen.getByText('common.operation.delete')).toBeInTheDocument()
|
||||
expect(screen.getByText('Modified content')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react'
|
|||
import { Robot, User } from '@/app/components/base/icons/src/public/avatar'
|
||||
import Textarea from '@/app/components/base/textarea'
|
||||
import Button from '@/app/components/base/button'
|
||||
import cn from '@/utils/classnames'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
export enum EditItemType {
|
||||
Query = 'query',
|
||||
|
|
@ -52,8 +52,14 @@ const EditItem: FC<Props> = ({
|
|||
}, [content])
|
||||
|
||||
const handleSave = async () => {
|
||||
await onSave(newContent)
|
||||
setIsEdit(false)
|
||||
try {
|
||||
await onSave(newContent)
|
||||
setIsEdit(false)
|
||||
}
|
||||
catch {
|
||||
// Keep edit mode open when save fails
|
||||
// Error notification is handled by the parent component
|
||||
}
|
||||
}
|
||||
|
||||
const handleCancel = () => {
|
||||
|
|
@ -96,9 +102,16 @@ const EditItem: FC<Props> = ({
|
|||
<div className='mr-2'>·</div>
|
||||
<div
|
||||
className='flex cursor-pointer items-center space-x-1'
|
||||
onClick={() => {
|
||||
setNewContent(content)
|
||||
onSave(content)
|
||||
onClick={async () => {
|
||||
try {
|
||||
await onSave(content)
|
||||
// Only update UI state after successful delete
|
||||
setNewContent(content)
|
||||
}
|
||||
catch {
|
||||
// Delete action failed - error is already handled by parent
|
||||
// UI state remains unchanged, user can retry
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className='h-3.5 w-3.5'>
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue