Merge branch 'main' into feat-agent-mask

GuanMu 2025-12-22 10:57:42 +08:00 committed by GitHub
commit 79c4baa7c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1037 changed files with 54320 additions and 2549 deletions

View File

@ -1,5 +1,5 @@
---
name: Dify Frontend Testing
name: frontend-testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
---
@ -178,7 +178,7 @@ Process in this order for multi-file testing:
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first
> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
## Testing Strategy
@ -289,17 +289,18 @@ For each test file generated, aim for:
- ✅ **>95%** branch coverage
- ✅ **>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
## Detailed Guides
For more detailed information, refer to:
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `guides/mocking.md` - Mock patterns and best practices
- `guides/async-testing.md` - Async operations and API calls
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
- `guides/common-patterns.md` - Frequently used testing patterns
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `references/mocking.md` - Mock patterns and best practices
- `references/async-testing.md` - Async operations and API calls
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `references/common-patterns.md` - Frequently used testing patterns
- `references/checklist.md` - Test generation checklist and validation steps
## Authoritative References

1 .codex/skills Symbolic link
View File

@ -0,0 +1 @@
../.claude/skills

View File

@ -6,7 +6,7 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

18 .github/CODEOWNERS vendored
View File

@ -6,6 +6,12 @@
* @crazywoola @laipz8200 @Yeuoly
# CODEOWNERS file
.github/CODEOWNERS @laipz8200 @crazywoola
# Docs
docs/ @crazywoola
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
@ -116,11 +122,17 @@ api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200
api/migrations/ @snakevash @laipz8200 @MRZHUH
# Backend - Vector DB Middleware
api/configs/middleware/vdb/* @JohnJyong
# Frontend
web/ @iamjoel
# Frontend - Web Tests
.github/workflows/web-tests.yml @iamjoel
# Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh
@ -192,6 +204,7 @@ web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel
@ -232,3 +245,6 @@ web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
# Docker
docker/* @laipz8200

View File

@ -66,7 +66,7 @@ jobs:
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
- name: Install pnpm
uses: pnpm/action-setup@v4
@ -79,7 +79,7 @@ jobs:
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
working-directory: ./web

View File

@ -90,7 +90,7 @@ jobs:
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'

View File

@ -55,7 +55,7 @@ jobs:
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: env.FILES_CHANGED == 'true'

View File

@ -13,6 +13,7 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
@ -21,14 +22,7 @@ jobs:
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: web/**
- name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
@ -36,23 +30,355 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Restore Jest cache
uses: actions/cache@v4
with:
path: web/.cache/jest
key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-jest-
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm test
run: |
pnpm exec jest \
--ci \
--maxWorkers=100% \
--coverage \
--passWithNoTests
- name: Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Jest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v4
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error

View File

@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],

View File

@ -690,3 +690,8 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
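
These three settings drive the new sandbox retention cleanup. A minimal sketch of how they might combine, assuming a simple batch-pruning job (the function name and record shape below are illustrative assumptions, not the actual task implementation):

```python
from datetime import datetime, timedelta, timezone

GRACEFUL_PERIOD_DAYS = 21  # SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
BATCH_SIZE = 1000          # SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE
RETENTION_DAYS = 30        # SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS

def batches_to_delete(
    subscription_expired_at: datetime,
    record_timestamps: list[datetime],
) -> list[list[datetime]]:
    """Return deletion batches once the graceful period has elapsed."""
    now = datetime.now(timezone.utc)
    # Inside the graceful period after expiration, keep everything.
    if now < subscription_expired_at + timedelta(days=GRACEFUL_PERIOD_DAYS):
        return []
    cutoff = now - timedelta(days=RETENTION_DAYS)
    stale = [ts for ts in record_timestamps if ts < cutoff]
    # Delete in fixed-size batches to bound each transaction.
    return [stale[i:i + BATCH_SIZE] for i in range(0, len(stale), BATCH_SIZE)]
```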

View File

@ -84,7 +84,7 @@
1. If you need to handle and debug async tasks (e.g. dataset importing and document indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
default=600.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
@ -1270,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
)
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1295,6 +1310,7 @@ class FeatureConfig(
PositionConfig,
RagEtlConfig,
RepositoryConfig,
SandboxExpiredRecordsCleanConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,

View File

@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None

View File

@ -40,7 +40,7 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
class CompletionMessageExplorePayload(BaseModel):
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
# define completion api for user
@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming"

View File

@ -1,5 +1,4 @@
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import marshal_with
@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
from models.model import AppMode
@ -24,7 +24,7 @@ from .. import console_ns
class ConversationListQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None

View File

@ -2,7 +2,8 @@ import logging
from typing import Any
from flask import request
from flask_restx import Resource, inputs, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@ -18,6 +19,15 @@ from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
class InstalledAppCreatePayload(BaseModel):
app_id: str
class InstalledAppUpdatePayload(BaseModel):
is_pinned: bool | None = None
logger = logging.getLogger(__name__)
@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
def post(self):
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
if recommended_app is None:
raise NotFound("App not found")
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == args["app_id"]).first()
app = db.session.query(App).where(App.id == payload.app_id).first()
if app is None:
raise NotFound("App not found")
raise NotFound("App entity not found")
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = (
db.session.query(InstalledApp)
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
.first()
)
@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
recommended_app.install_count += 1
new_installed_app = InstalledApp(
app_id=args["app_id"],
app_id=payload.app_id,
tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id,
is_pinned=False,
@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app):
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
args = parser.parse_args()
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
commit_args = False
if "is_pinned" in args:
installed_app.is_pinned = args["is_pinned"]
if payload.is_pinned is not None:
installed_app.is_pinned = payload.is_pinned
commit_args = True
if commit_args:

View File

@ -1,31 +1,40 @@
from typing import Literal
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 50:
raise ValueError("Name must be between 1 to 50 characters.")
return name
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
parser_tags = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
)
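
With the Pydantic payloads, the 1–50 character name rule previously enforced by `_validate_name` is now declarative (`min_length`/`max_length`). A quick illustration of the resulting behavior (sketch only):

```python
from pydantic import ValidationError

TagBasePayload.model_validate({"name": "docs", "type": "app"})  # passes

try:
    TagBasePayload.model_validate({"name": "", "type": "app"})  # violates min_length=1
except ValidationError:
    pass  # rejected, matching the old "between 1 to 50 characters" rule
```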
@ -43,7 +52,7 @@ class TagListApi(Resource):
return tags, 200
@console_ns.expect(parser_tags)
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -53,22 +62,17 @@ class TagListApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_tags.parse_args()
tag = TagService.save_tags(args)
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump())
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
parser_tag_id = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@console_ns.expect(parser_tag_id)
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_tag_id.parse_args()
tag = TagService.update_tags(args, tag_id)
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource):
return 204
parser_create = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@console_ns.expect(parser_create)
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_create.parse_args()
TagService.save_tag_binding(args)
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
parser_remove = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@console_ns.expect(parser_remove)
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_remove.parse_args()
TagService.delete_tag_binding(args)
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump())
return {"result": "success"}, 200

View File

@ -18,6 +18,7 @@ from controllers.console.wraps import (
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
# Create provider in transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource):
configuration=configuration,
authentication=authentication,
)
return jsonable_encoder(result)
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result)
@console_ns.expect(parser_mcp_put)
@setup_required
@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=args["provider_id"]
)
# No need to check for errors here, exceptions will be raised directly
# Step 2: Perform URL validation with network I/O OUTSIDE of any database session
# This prevents holding database locks during potentially slow network operations
validation_result = MCPToolManageService.validate_server_url_standalone(
tenant_id=current_tenant_id,
new_server_url=args["server_url"],
validation_data=validation_data,
)
# Step 2: Perform database update in a transaction
# Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource):
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"}
@console_ns.expect(parser_mcp_delete)
@setup_required
@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"}
parser_auth = (

View File

@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None

View File

@ -1,7 +1,8 @@
import logging
from flask import request
from flask_restx import fields, marshal_with, reqparse
from flask_restx import fields, marshal_with
from pydantic import BaseModel, field_validator
from werkzeug.exceptions import InternalServerError
import services
@ -20,6 +21,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
@ -29,6 +31,25 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError,
)
from ..common.schema import register_schema_models
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
register_schema_models(web_ns, TextToAudioPayload)
logger = logging.getLogger(__name__)
@ -88,6 +109,7 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
@web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(
@ -102,18 +124,11 @@ class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
"""Convert text to audio"""
try:
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = payload.message_id
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)

View File

@ -1,9 +1,11 @@
import logging
from typing import Any, Literal
from flask_restx import reqparse
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the completion")
query: str = Field(default="", description="Query text for completion")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
response_mode: Literal["blocking", "streaming"] | None = Field(
default=None, description="Response mode: blocking or streaming"
)
retriever_from: str = Field(default="web_app", description="Source of retriever")
class ChatMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the chat")
query: str = Field(description="User query/message")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
response_mode: Literal["blocking", "streaming"] | None = Field(
default=None, description="Response mode: blocking or streaming"
)
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
retriever_from: str = Field(default="web_app", description="Source of retriever")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user
@web_ns.route("/completion-messages")
class CompletionApi(WebApiResource):
@web_ns.doc("Create Completion Message")
@web_ns.doc(description="Create a completion message for text generation applications.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
"query": {"description": "Query text for completion", "type": "string", "required": False},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@ -67,18 +88,10 @@ class CompletionApi(WebApiResource):
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:
@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
@web_ns.doc("Create Chat Message")
@web_ns.doc(description="Create a chat message for conversational applications.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
"query": {"description": "User query/message", "type": "string", "required": True},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@ -173,20 +171,10 @@ class ChatApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
json_schema: str | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
def validate_json_schema(cls, schema: str | None) -> str | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
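
Since `json_schema` now travels as a plain string, validation is a two-step check: parse the string, then validate the parsed object against the Draft-7 metaschema. A standalone sketch of the same checks:

```python
import json
from jsonschema import Draft7Validator

def check(schema: str) -> str:
    parsed = json.loads(schema)           # bad JSON -> json.JSONDecodeError
    Draft7Validator.check_schema(parsed)  # bad schema -> jsonschema SchemaError
    return schema

check('{"type": "object"}')   # passes both steps
# check('{"type": object}')   # json.JSONDecodeError: not valid JSON
# check('{"enum": 5}')        # SchemaError: "enum" must be an array
```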

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@ -175,6 +176,13 @@ class BaseAppGenerator:
value = True
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _:
raise AssertionError("this statement should be unreachable.")

View File

@ -342,9 +342,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=event_type,
)
else:
yield self._agent_message_to_stream_response(

View File

@ -5,7 +5,7 @@ from threading import Thread
from typing import Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from configs import dify_config
@ -54,6 +54,20 @@ class MessageCycleManager:
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state
self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent:
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE
with Session(db.engine, expire_on_commit=False) as session:
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
if has_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE
return StreamEvent.MESSAGE
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
@ -214,7 +228,11 @@ class MessageCycleManager:
return None
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
self,
answer: str,
message_id: str,
from_variable_selector: list[str] | None = None,
event_type: StreamEvent | None = None,
) -> MessageStreamResponse:
"""
Message to stream response.
@ -222,16 +240,12 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
event=event_type,
event=event_type or StreamEvent.MESSAGE,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

View File

@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL.
Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
Priority order:
1. URL from WWW-Authenticate header (if provided)
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
"""
urls = []
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
path = parsed.path.rstrip("/")
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
if path:
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
if path_url not in urls:
urls.append(path_url)
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
root_url = f"{base_url}/.well-known/oauth-protected-resource"
if root_url not in urls:
urls.append(root_url)
return urls
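
A quick illustration of the resulting priority order. The call below is hypothetical (the signature is truncated above, so the argument names and order are assumptions), but the output follows directly from the code:

```python
# Assumed parameters: the WWW-Authenticate-supplied metadata URL (None here)
# and the MCP server URL.
urls = build_protected_resource_metadata_discovery_urls(
    None,
    "https://example.com/public/mcp",
)
assert urls == [
    "https://example.com/.well-known/oauth-protected-resource/public/mcp",
    "https://example.com/.well-known/oauth-protected-resource",
]
```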
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
- OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
- OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
- OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
- OpenID Connect at root: https://example.com/.well-known/openid-configuration
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
path = parsed.path.rstrip("/")
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
urls.append(f"{base}/.well-known/oauth-authorization-server")
# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OpenID Connect Discovery at root
urls.append(f"{base}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
# OpenID Connect Discovery with path insertion
urls.append(f"{base}/.well-known/openid-configuration{path}")
# OpenID Connect Discovery path appending
urls.append(f"{base}{path}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata with path insertion
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
return urls
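
For an issuer URL that carries a path, the function yields all five candidate endpoints. Note that the append order in the code places the root forms first, even though the docstring lists the path-insertion forms first; the sketch below follows the code as written (the second argument is an assumption based on the truncated signature):

```python
urls = build_oauth_authorization_server_metadata_discovery_urls(
    "https://auth.example.com/tenant1",  # auth_server_url
    "https://auth.example.com/mcp",      # assumed server_url fallback
)
assert urls == [
    "https://auth.example.com/.well-known/oauth-authorization-server",
    "https://auth.example.com/.well-known/openid-configuration",
    "https://auth.example.com/.well-known/openid-configuration/tenant1",
    "https://auth.example.com/tenant1/.well-known/openid-configuration",
    "https://auth.example.com/.well-known/oauth-authorization-server/tenant1",
]
```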

View File

@ -59,7 +59,7 @@ class MCPClient:
try:
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
self.connect_server(sse_client, "sse")
except MCPConnectionError:
except (MCPConnectionError, ValueError):
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")

View File

@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model
- Model provider display
![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png)
Displays a list of all supported providers, including provider names, icons, supported model types, predefined model lists, configuration methods, and credential form rules. For the detailed rule design, see: [Schema](./docs/en_US/schema.md).
Displays a list of all supported providers, including provider names, icons, supported model types, predefined model lists, configuration methods, and credential form rules.
- Selectable model list display
![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png)
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png)
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
- Provider/model credential authentication
![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png)
![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png)
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
## Structure
![](./docs/en_US/images/index/image-20231210165243632.png)
Model Runtime is divided into three layers:
- The outermost layer is the factory method
@ -60,9 +46,6 @@ Model Runtime is divided into three layers:
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
## Next Steps
## Documentation
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
- Implement interface methods: [Link](./docs/en_US/interfaces.md)
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).

View File

@ -18,34 +18,20 @@
- Model provider display
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
Displays the list of all supported providers, returning provider names, icons, supported model types, predefined model lists, configuration methods, and credential form rules. For the rule design, see: [Schema](./docs/zh_Hans/schema.md).
Displays the list of all supported providers, returning provider names, icons, supported model types, predefined model lists, configuration methods, and credential form rules.
- Selectable model list display
![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png)
After configuring provider/model credentials, the dropdown (app orchestration UI / default model) shows the available LLM list; greyed-out entries are predefined models from providers whose credentials are not configured, making it easy to see which models are supported.
In addition, this list also returns the configurable parameter information and rules for LLMs, as shown below:
![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png)
These parameters are all defined in the backend; instead of the previous five fixed parameters, each model can declare the parameters it supports. For details, see: [Schema](./docs/zh_Hans/schema.md#ParameterRule).
In addition, this list also returns the configurable parameter information and rules for LLMs. These parameters are all defined in the backend; instead of the previous five fixed parameters, each model can declare the parameters it supports.
- Provider/model credential authentication
![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png)
![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png)
The provider list returns the credential form configuration, which can be authenticated through the interfaces provided by Runtime. Image 1 above is a provider credential demo, and image 2 is a model credential demo.
The provider list returns the credential form configuration, which can be authenticated through the interfaces provided by Runtime.
## Structure
![](./docs/zh_Hans/images/index/image-20231210165243632.png)
Model Runtime has three layers:
- The outermost layer is the factory method
@ -59,8 +45,7 @@ Model Runtime has three layers:
For provider/model credentials, there are two cases
- Centralized providers such as OpenAI require authentication credentials like **api_key**
- Locally deployed providers such as [**Xinference**](https://github.com/xorbitsai/inference) require address credentials like **server_url**, and sometimes model-type credentials like **model_uid**, as shown below. Once these credentials are defined at the provider layer, they can be displayed directly on the frontend without modifying frontend logic.
![Alt text](docs/zh_Hans/images/index/image.png)
- Locally deployed providers such as [**Xinference**](https://github.com/xorbitsai/inference) require address credentials like **server_url**, and sometimes model-type credentials like **model_uid**. Once these credentials are defined at the provider layer, they can be displayed directly on the frontend without modifying frontend logic.
Once credentials are configured, the **Schema** (credential form rules) a provider needs can be fetched directly through DifyRuntime's external interface, so support for new providers/models can be added without modifying frontend logic.
@ -74,20 +59,6 @@ Model Runtime has three layers:
- Model credentials (**defined at the provider layer**): parameters that rarely change once set, such as **api_key** and **server_url**. In DifyRuntime their parameter name is typically **credentials: dict[str, any]**; the Provider layer's credentials are passed straight through to this layer and do not need to be defined separately.
## Next Steps
## Documentation
### [Add a new provider configuration 👈🏻](./docs/zh_Hans/provider_scale_out.md)
Once added, a new provider will appear here
![Alt text](docs/zh_Hans/images/index/image-1.png)
### [Add models to an existing provider 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
Once added, a new predefined model appears in that provider's model list for users to choose, such as GPT-3.5, GPT-4, or ChatGLM3-6b; providers that support custom models do not need new models added.
![Alt text](docs/zh_Hans/images/index/image-2.png)
### [Implementing the interfaces 👈🏻](./docs/zh_Hans/interfaces.md)
Here you can find the concrete implementations of the interfaces you want to inspect, along with the exact meaning of their parameters and return values.
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).

View File

@ -39,7 +39,7 @@ from core.trigger.errors import (
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None,
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:
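A minimal sketch of the timeout coercion this hunk feeds into (assuming standard httpx semantics, where a bare float sets all four timeout phases; the branching mirrors the `float | httpx.Timeout | None` union above):

```python
import httpx

_cfg: float | httpx.Timeout | None = 600.0   # e.g. the new default above

if _cfg is None:
    request_timeout = None                   # disable the timeout entirely
elif isinstance(_cfg, httpx.Timeout):
    request_timeout = _cfg                   # already fully specified
else:
    request_timeout = httpx.Timeout(_cfg)    # one float for connect/read/write/pool
```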

View File

@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor):
def _extract_images_from_docx(self, doc):
image_count = 0
image_map = {}
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref:
@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor):
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use r_id as key for external images since target_part is undefined
image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
image_map[r_id] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor):
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use target_part as key for internal images
image_map[rel.target_part] = (
f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
)
image_map[rel.target_part] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
db.session.commit()
return image_map

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import codecs
import re
from typing import Any
@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._fixed_separator = fixed_separator
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""]
def split_text(self, text: str) -> list[str]:
@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = re.split(r" +", text)
else:
splits = text.split(separator)
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
if self._keep_separator:
splits = [s + separator for s in splits[:-1]] + splits[-1:]
else:
splits = list(text)
if separator == "\n":
@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = separator if self._keep_separator else ""
_separator = "" if self._keep_separator else separator
s_lens = self._length_function(splits)
if separator != "":
for s, s_len in zip(splits, s_lens):
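A minimal sketch of why the `unicode_escape` decode added above matters (assuming the separator arrives from a form as literal backslash sequences rather than real control characters):

```python
import codecs

raw = "\\n\\n---\\n\\n"                      # what the user typed: backslash-n, not newline
decoded = codecs.decode(raw, "unicode_escape")
assert decoded == "\n\n---\n\n"              # the separator actually present in the text
```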

View File

@ -86,6 +86,11 @@ class Executor:
node_data.authorization.config.api_key = variable_pool.convert_template(
node_data.authorization.config.api_key
).text
# Validate that API key is not empty after template conversion
if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip():
raise AuthorizationConfigError(
"API key is required for authorization but was empty. Please provide a valid API key."
)
self.url = node_data.url
self.method = node_data.method

View File

@ -1,3 +1,4 @@
import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
if not isinstance(value, dict):
raise ValueError(f"{key} must be a JSON object")
schema = variable.json_schema
if not schema:
continue
if not value:
continue
try:
Draft7Validator(schema).validate(value)
json_schema = json.loads(schema)
except json.JSONDecodeError as e:
raise ValueError(f"{schema} must be a valid JSON object")
try:
json_value = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"{value} must be a valid JSON object")
try:
Draft7Validator(json_schema).validate(json_value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = value
node_inputs[key] = json_value
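A self-contained sketch of the validation flow above, using the real `jsonschema` package (the schema and value strings are examples): both the schema and the user-supplied value arrive as JSON strings and must be parsed before Draft 7 validation can run.

```python
import json

from jsonschema import Draft7Validator, ValidationError

schema = json.loads('{"type": "object", "required": ["age"]}')
value = json.loads('{"age": 20, "name": "Tom"}')
try:
    Draft7Validator(schema).validate(value)  # raises ValidationError on mismatch
except ValidationError as e:
    print(f"JSON object does not match schema: {e.message}")
```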

View File

@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"
@ -69,6 +69,53 @@ if [[ "${MODE}" == "worker" ]]; then
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
elif [[ "${MODE}" == "job" ]]; then
# Job mode: Run a one-time Flask command and exit
# Pass Flask command and arguments via container args
# Example K8s usage:
# args:
# - create-tenant
# - --email
# - admin@example.com
#
# Example Docker usage:
# docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com
if [[ $# -eq 0 ]]; then
echo "Error: No command specified for job mode."
echo ""
echo "Usage examples:"
echo " Kubernetes:"
echo " args: [create-tenant, --email, admin@example.com]"
echo ""
echo " Docker:"
echo " docker run -e MODE=job dify-api create-tenant --email admin@example.com"
echo ""
echo "Available commands:"
echo " create-tenant, reset-password, reset-email, upgrade-db,"
echo " vdb-migrate, install-plugins, and more..."
echo ""
echo "Run 'flask --help' to see all available commands."
exit 1
fi
echo "Running Flask job command: flask $*"
# Temporarily disable exit on error to capture exit code
set +e
flask "$@"
JOB_EXIT_CODE=$?
set -e
if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then
echo "Job completed successfully."
else
echo "Job failed with exit code ${JOB_EXIT_CODE}."
fi
exit ${JOB_EXIT_CODE}
else
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug

View File

@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage):
if not self.exists(path):
raise FileNotFoundError("Path not found")
all_files = self.op.scan(path=path)
# Use the new OpenDAL 0.46.0+ API with recursive listing
lister = self.op.list(path, recursive=True)
if files and directories:
logger.debug("files and directories on %s scanned", path)
return [f.path for f in all_files]
return [entry.path for entry in lister]
if files:
logger.debug("files on %s scanned", path)
return [f.path for f in all_files if not f.path.endswith("/")]
return [entry.path for entry in lister if not entry.metadata.is_dir]
elif directories:
logger.debug("directories on %s scanned", path)
return [f.path for f in all_files if f.path.endswith("/")]
return [entry.path for entry in lister if entry.metadata.is_dir]
else:
raise ValueError("At least one of files or directories must be True")

View File

@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
from uuid import UUID
from zoneinfo import available_timezones
from flask import Response, stream_with_context
@ -119,6 +120,19 @@ def uuid_value(value: Any) -> str:
raise ValueError(error)
def normalize_uuid(value: str | UUID) -> str:
if not value:
return ""
try:
return uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)]
def alphanumeric(value: str):
# check if the value contains only alphanumeric characters and underscores
if re.match(r"^[a-zA-Z0-9_]+$", value):
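A usage sketch for the `UUIDStrOrEmpty` annotated type added above (assuming Pydantic v2 semantics, where the `AfterValidator` runs on the already-validated string, and that the type is importable from this helper module):

```python
from pydantic import TypeAdapter

# Assumes UUIDStrOrEmpty is imported from the helper module above.
adapter = TypeAdapter(UUIDStrOrEmpty)

assert adapter.validate_python("") == ""     # empty input passes through as ""
normalized = adapter.validate_python("550e8400-e29b-41d4-a716-446655440000")
# adapter.validate_python("not-a-uuid")      # raises: "must be a valid UUID"
```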

View File

@ -1,8 +1,12 @@
import logging
import os
from collections.abc import Sequence
from typing import Literal
import httpx
from pydantic import TypeAdapter
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError
from enums.cloud_plan import CloudPlan
@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
from models import Account, TenantAccountJoin, TenantAccountRole
logger = logging.getLogger(__name__)
class SubscriptionPlan(TypedDict):
"""Tenant subscriptionplan information."""
plan: str
expiration_date: int
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
@ -239,3 +252,39 @@ class BillingService:
def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
payload = {"account_id": account_id, "click_id": click_id}
return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)
@classmethod
def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
"""
Bulk fetch billing subscription plan via billing API.
Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
Returns:
Mapping of tenant_id -> {plan: str, expiration_date: int}
"""
results: dict[str, SubscriptionPlan] = {}
subscription_adapter = TypeAdapter(SubscriptionPlan)
chunk_size = 200
for i in range(0, len(tenant_ids), chunk_size):
chunk = tenant_ids[i : i + chunk_size]
try:
resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
data = resp.get("data", {})
for tenant_id, plan in data.items():
subscription_plan = subscription_adapter.validate_python(plan)
results[tenant_id] = subscription_plan
except Exception:
logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
continue
return results
@classmethod
def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
data = resp.get("data", [])
return [item["tenant_id"] for item in data]
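A hedged call-site sketch for `get_plan_bulk` (tenant IDs are examples; per the implementation above, chunks that fail are logged and simply absent from the result):

```python
plans = BillingService.get_plan_bulk(["tenant-a", "tenant-b"])
for tenant_id, plan in plans.items():
    # Each value follows the SubscriptionPlan TypedDict defined above.
    print(tenant_id, plan["plan"], plan["expiration_date"])
```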

View File

@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel):
description: str
icon_info: IconInfo
permission: str
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
yaml_content: str | None = None

View File

@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
class ProviderUrlValidationData(BaseModel):
"""Data required for URL validation, extracted from database to perform network operations outside of session"""
current_server_url_hash: str
headers: dict[str, str]
timeout: float | None
sse_read_timeout: float | None
class MCPToolManageService:
"""Service class for managing MCP tools and providers."""
@ -166,9 +174,6 @@ class MCPToolManageService:
self._session.add(mcp_tool)
self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
@ -192,7 +197,7 @@ class MCPToolManageService:
Update an MCP provider.
Args:
validation_result: Pre-validation result from validate_server_url_change.
validation_result: Pre-validation result from validate_server_url_standalone.
If provided and contains reconnect_result, it will be used
instead of performing network operations.
"""
@ -251,8 +256,6 @@ class MCPToolManageService:
# Flush changes to database
self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier)
@ -261,9 +264,6 @@ class MCPToolManageService:
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]:
@ -546,30 +546,39 @@ class MCPToolManageService:
)
return self.execute_auth_actions(auth_result)
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
"""Attempt to reconnect to MCP provider with new server URL."""
def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
"""
Get provider data required for URL validation.
This method performs a database read and should be called within a session.
Returns:
ProviderUrlValidationData: Data needed for standalone URL validation
"""
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = provider.to_entity()
headers = provider_entity.headers
return ProviderUrlValidationData(
current_server_url_hash=provider.server_url_hash,
headers=provider_entity.headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
)
try:
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def validate_server_url_change(
self, *, tenant_id: str, provider_id: str, new_server_url: str
@staticmethod
def validate_server_url_standalone(
*,
tenant_id: str,
new_server_url: str,
validation_data: ProviderUrlValidationData,
) -> ServerUrlValidationResult:
"""
Validate server URL change by attempting to connect to the new server.
This method should be called BEFORE update_provider to perform network operations
outside of the database transaction.
This method performs network operations and MUST be called OUTSIDE of any database session
to avoid holding locks during network I/O.
Args:
tenant_id: Tenant ID for encryption
new_server_url: The new server URL to validate
validation_data: Provider data obtained from get_provider_for_url_validation
Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful
@ -579,25 +588,30 @@ class MCPToolManageService:
return ServerUrlValidationResult(needs_validation=False)
# Validate URL format
if not self._is_valid_url(new_server_url):
parsed = urlparse(new_server_url)
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
raise ValueError("Server URL is not valid.")
# Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
# Get current provider
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check if URL is actually different
if new_server_url_hash == provider.server_url_hash:
if new_server_url_hash == validation_data.current_server_url_hash:
# URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult(
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
needs_validation=False,
encrypted_server_url=encrypted_server_url,
server_url_hash=new_server_url_hash,
)
# Perform validation by attempting to connect
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
# Perform network validation - this is the expensive operation that should be outside session
reconnect_result = MCPToolManageService._reconnect_with_url(
server_url=new_server_url,
headers=validation_data.headers,
timeout=validation_data.timeout,
sse_read_timeout=validation_data.sse_read_timeout,
)
return ServerUrlValidationResult(
needs_validation=True,
validation_passed=True,
@ -606,6 +620,38 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash,
)
@staticmethod
def _reconnect_with_url(
*,
server_url: str,
headers: dict[str, str],
timeout: float | None,
sse_read_timeout: float | None,
) -> ReconnectResult:
"""
Attempt to connect to MCP server with given URL.
This is a static method that performs network I/O without database access.
"""
from core.mcp.mcp_client import MCPClient
try:
with MCPClient(
server_url=server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
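A sketch of the intended call order for the refactor above (method names come from this diff; the surrounding session management is assumed to be handled by the caller):

```python
# 1) Inside a DB session: read only the data validation needs.
validation_data = service.get_provider_for_url_validation(
    tenant_id=tenant_id, provider_id=provider_id
)

# 2) Outside any session: network I/O against the new URL, with no locks held.
result = MCPToolManageService.validate_server_url_standalone(
    tenant_id=tenant_id,
    new_server_url=new_server_url,
    validation_data=validation_data,
)

# 3) Back inside a session: persist the change, passing the validation result
#    so update_provider performs no network operations of its own.
```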

View File

@ -2,7 +2,6 @@ import logging
import time
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import select
@ -12,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceOauthBinding
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@ -48,27 +47,36 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
sa.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
.first()
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
db.session.commit()
db.session.close()
return
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)

View File

@ -6,6 +6,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory
@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
ssl_verify=True,
)
# Create executor
executor = Executor(
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
)
# Get assembled headers
headers = executor._assembling_headers()
# When api_key is empty, the custom header should NOT be set
assert "X-Custom-Auth" not in headers
# Create executor should raise AuthorizationConfigError
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10),
variable_pool=variable_pool,
)
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_with_empty_api_key(setup_http_mock):
"""
Test that custom authorization doesn't set header when api_key is empty.
This test verifies the fix for issue #23554.
Test that custom authorization raises error when api_key is empty.
This test verifies the fix for issue #21830.
"""
node = init_http_node(
config={
"id": "1",
@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# Custom header should NOT be set when api_key is empty
assert "X-Custom-Auth:" not in data
# Should fail with AuthorizationConfigError
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "API key is required" in result.error
assert result.error_type == "AuthorizationConfigError"
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)

View File

@ -2,7 +2,9 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import TypeAdapter, ValidationError
from core.tools.entities.tool_entities import ApiProviderSchemaType
from models import Account, Tenant
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService
@ -298,7 +300,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
schema_type = "openapi"
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@ -364,7 +366,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none"}
schema_type = "openapi"
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@ -428,21 +430,10 @@ class TestApiToolManageService:
labels = ["test"]
# Act & Assert: Try to create provider with invalid schema type
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
with pytest.raises(ValidationError) as exc_info:
TypeAdapter(ApiProviderSchemaType).validate_python(schema_type)
assert "invalid schema type" in str(exc_info.value)
assert "validation error" in str(exc_info.value)
def test_create_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
@ -464,7 +455,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {} # Missing auth_type
schema_type = "openapi"
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@ -507,7 +498,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔑"}
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
schema_type = "openapi"
schema_type = ApiProviderSchemaType.OPENAPI
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"

View File

@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
result = MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
# Assert: Verify the expected outcomes
@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
assert tools_data[1]["name"] == "test_tool_2"
# Verify mock interactions
provider_entity = mcp_provider.to_entity()
mock_mcp_client.assert_called_once()
mock_mcp_client.assert_called_once_with(
server_url="https://example.com/mcp",
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
result = MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
# Assert: Verify the expected outcomes
@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
service._reconnect_provider(
MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)

View File

@ -0,0 +1,420 @@
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
import pytest
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueErrorEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
QueuePingEvent,
)
from core.app.entities.task_entities import (
EasyUITaskState,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
PingStreamResponse,
StreamEvent,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
"""Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method."""
@pytest.fixture
def mock_application_generate_entity(self):
"""Create a mock application generate entity."""
entity = Mock(spec=ChatAppGenerateEntity)
entity.task_id = "test-task-id"
entity.app_id = "test-app-id"
# minimal app_config used by pipeline internals
entity.app_config = SimpleNamespace(
tenant_id="test-tenant-id",
app_id="test-app-id",
app_mode=AppMode.CHAT,
app_model_config_dict={},
additional_features=None,
sensitive_word_avoidance=None,
)
# minimal model_conf for LLMResult init
entity.model_conf = SimpleNamespace(
model="test-model",
provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
credentials={},
)
return entity
@pytest.fixture
def mock_queue_manager(self):
"""Create a mock queue manager."""
manager = Mock(spec=AppQueueManager)
return manager
@pytest.fixture
def mock_message_cycle_manager(self):
"""Create a mock message cycle manager."""
manager = Mock()
manager.get_message_event_type.return_value = StreamEvent.MESSAGE
manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
manager.handle_retriever_resources = Mock()
manager.handle_annotation_reply.return_value = None
return manager
@pytest.fixture
def mock_conversation(self):
"""Create a mock conversation."""
conversation = Mock()
conversation.id = "test-conversation-id"
conversation.mode = "chat"
return conversation
@pytest.fixture
def mock_message(self):
"""Create a mock message."""
message = Mock()
message.id = "test-message-id"
message.created_at = Mock()
message.created_at.timestamp.return_value = 1234567890
return message
@pytest.fixture
def mock_task_state(self):
"""Create a mock task state."""
task_state = Mock(spec=EasyUITaskState)
# Create LLM result mock
llm_result = Mock(spec=RuntimeLLMResult)
llm_result.prompt_messages = []
llm_result.message = Mock()
llm_result.message.content = ""
task_state.llm_result = llm_result
task_state.answer = ""
return task_state
@pytest.fixture
def pipeline(
self,
mock_application_generate_entity,
mock_queue_manager,
mock_conversation,
mock_message,
mock_message_cycle_manager,
mock_task_state,
):
"""Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies."""
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
):
pipeline = EasyUIBasedGenerateTaskPipeline(
application_generate_entity=mock_application_generate_entity,
queue_manager=mock_queue_manager,
conversation=mock_conversation,
message=mock_message,
stream=True,
)
pipeline._message_cycle_manager = mock_message_cycle_manager
pipeline._task_state = mock_task_state
return pipeline
def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
self, pipeline, mock_message_cycle_manager
):
"""Expect get_message_event_type to be called when processing the first LLM chunk event."""
# Setup a minimal LLM chunk event
chunk = Mock()
chunk.delta.message.content = "hi"
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
# Execute
list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")
def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of LLM chunk events with text content."""
# Setup
chunk = Mock()
chunk.delta.message.content = "Hello, world!"
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
assert mock_task_state.llm_result.message.content == "Hello, world!"
def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of LLM chunk events with list content."""
# Setup
text_content = Mock(spec=TextPromptMessageContent)
text_content.data = "Hello"
chunk = Mock()
chunk.delta.message.content = [text_content, " world!"]
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
assert mock_task_state.llm_result.message.content == "Hello world!"
def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of agent message events."""
# Setup
chunk = Mock()
chunk.delta.message.content = "Agent response"
agent_message_event = Mock(spec=QueueAgentMessageEvent)
agent_message_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = agent_message_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
# Ensure method under assertion is a mock to track calls
pipeline._agent_message_to_stream_response = Mock(return_value=Mock())
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
# Agent messages should use _agent_message_to_stream_response
pipeline._agent_message_to_stream_response.assert_called_once_with(
answer="Agent response", message_id="test-message-id"
)
def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of message end events."""
# Setup
llm_result = Mock(spec=RuntimeLLMResult)
llm_result.message = Mock()
llm_result.message.content = "Final response"
message_end_event = Mock(spec=QueueMessageEndEvent)
message_end_event.llm_result = llm_result
mock_queue_message = Mock()
mock_queue_message.event = message_end_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline._save_message = Mock()
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
assert mock_task_state.llm_result == llm_result
pipeline._save_message.assert_called_once()
pipeline._message_end_to_stream_response.assert_called_once()
def test_error_event(self, pipeline):
"""Test handling of error events."""
# Setup
error_event = Mock(spec=QueueErrorEvent)
error_event.error = Exception("Test error")
mock_queue_message = Mock()
mock_queue_message.event = error_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.handle_error = Mock(return_value=Exception("Test error"))
pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
pipeline.handle_error.assert_called_once()
pipeline.error_to_stream_response.assert_called_once()
def test_ping_event(self, pipeline):
"""Test handling of ping events."""
# Setup
ping_event = Mock(spec=QueuePingEvent)
mock_queue_message = Mock()
mock_queue_message.event = ping_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
pipeline.ping_stream_response.assert_called_once()
def test_file_event(self, pipeline, mock_message_cycle_manager):
"""Test handling of file events."""
# Setup
file_event = Mock(spec=QueueMessageFileEvent)
file_event.message_file_id = "file-id"
mock_queue_message = Mock()
mock_queue_message.event = file_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
file_response = Mock(spec=MessageFileStreamResponse)
mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
assert responses[0] == file_response
mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event)
def test_publisher_is_called_with_messages(self, pipeline):
"""Test that publisher publishes messages when provided."""
# Setup
publisher = Mock(spec=AppGeneratorTTSPublisher)
ping_event = Mock(spec=QueuePingEvent)
mock_queue_message = Mock()
mock_queue_message.event = ping_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
list(pipeline._process_stream_response(publisher=publisher, trace_manager=None))
# Assert
# Called once with message and once with None at the end
assert publisher.publish.call_count == 2
publisher.publish.assert_any_call(mock_queue_message)
publisher.publish.assert_any_call(None)
def test_trace_manager_passed_to_save_message(self, pipeline):
"""Test that trace manager is passed to _save_message."""
# Setup
trace_manager = Mock(spec=TraceQueueManager)
message_end_event = Mock(spec=QueueMessageEndEvent)
message_end_event.llm_result = None
mock_queue_message = Mock()
mock_queue_message.event = message_end_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline._save_message = Mock()
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager))
# Assert
pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager)
def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling multiple events in sequence."""
# Setup
chunk1 = Mock()
chunk1.delta.message.content = "Hello"
chunk1.prompt_messages = []
chunk2 = Mock()
chunk2.delta.message.content = " world!"
chunk2.prompt_messages = []
llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event1.chunk = chunk1
ping_event = Mock(spec=QueuePingEvent)
llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event2.chunk = chunk2
mock_queue_messages = [
Mock(event=llm_chunk_event1),
Mock(event=ping_event),
Mock(event=llm_chunk_event2),
]
pipeline.queue_manager.listen.return_value = mock_queue_messages
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 3
assert mock_task_state.llm_result.message.content == "Hello world!"
# Verify calls to message_to_stream_response
assert mock_message_cycle_manager.message_to_stream_response.call_count == 2
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)

View File

@ -0,0 +1,166 @@
"""Unit tests for the message cycle manager optimization."""
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
import pytest
from flask import current_app
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
class TestMessageCycleManagerOptimization:
"""Test cases for the message cycle manager optimization that prevents N+1 queries."""
@pytest.fixture
def mock_application_generate_entity(self):
"""Create a mock application generate entity."""
entity = Mock()
entity.task_id = "test-task-id"
return entity
@pytest.fixture
def message_cycle_manager(self, mock_application_generate_entity):
"""Create a message cycle manager instance."""
task_state = Mock()
return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Execute
with current_app.app_context():
result = message_cycle_manager.get_message_event_type("test-message-id")
# Assert
assert result == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message has no files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and no message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None
# Execute
with current_app.app_context():
result = message_cycle_manager.get_message_event_type("test-message-id")
# Assert
assert result == StreamEvent.MESSAGE
mock_session.query.return_value.scalar.assert_called_once()
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Execute: compute event type once, then pass to message_to_stream_response
with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id")
result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=event_type
)
# Assert
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
"""Test that message_to_stream_response skips database query when event_type is provided."""
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
# Execute with event_type provided
result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
# Assert
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE
# Should not query database when event_type is provided
mock_session_class.assert_not_called()
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
"""Test message_to_stream_response with from_variable_selector parameter."""
result = message_cycle_manager.message_to_stream_response(
answer="Hello world",
message_id="test-message-id",
from_variable_selector=["var1", "var2"],
event_type=StreamEvent.MESSAGE,
)
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.from_variable_selector == ["var1", "var2"]
assert result.event == StreamEvent.MESSAGE
def test_optimization_usage_example(self, message_cycle_manager):
"""Test the optimization pattern that should be used by callers."""
# Step 1: Get event type once (this queries database)
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None # No files
with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id")
# Should query database once
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
assert event_type == StreamEvent.MESSAGE
# Step 2: Use event_type for multiple calls (no additional queries)
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
mock_session_class.return_value.__enter__.return_value = Mock()
chunk1_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 1", message_id="test-message-id", event_type=event_type
)
chunk2_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 2", message_id="test-message-id", event_type=event_type
)
# Should not query database again
mock_session_class.assert_not_called()
assert chunk1_response.event == StreamEvent.MESSAGE
assert chunk2_response.event == StreamEvent.MESSAGE
assert chunk1_response.answer == "Chunk 1"
assert chunk2_response.answer == "Chunk 2"

View File

@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch):
# DB interactions should be recorded
assert len(db_stub.session.added) == 2
assert db_stub.session.committed is True
def test_extract_images_from_docx_uses_internal_files_url():
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
# Test the URL generation logic directly
from configs import dify_config
# Mock the configuration values
original_files_url = getattr(dify_config, "FILES_URL", None)
original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
try:
# Set both URLs - INTERNAL should take precedence
dify_config.FILES_URL = "http://external.example.com"
dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"
# Test the URL generation logic (same as in word_extractor.py)
upload_file_id = "test_file_id"
# This is the pattern we fixed in the word extractor
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
generated_url = f"{base_url}/files/{upload_file_id}/file-preview"
# Verify that INTERNAL_FILES_URL is used instead of FILES_URL
assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"
finally:
# Restore original values
if original_files_url is not None:
dify_config.FILES_URL = original_files_url
if original_internal_files_url is not None:
dify_config.INTERNAL_FILES_URL = original_internal_files_url

View File

@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter:
# Verify no empty chunks
assert all(len(chunk) > 0 for chunk in result)
def test_double_slash_n(self):
data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
separator = "\\n\\n---\\n\\n"
splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator)
chunks = splitter.split_text(data)
assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]
# ============================================================================
# Test Metadata Preservation

View File

@ -1,3 +1,5 @@
import pytest
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNodeAuthorization,
@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -348,3 +351,127 @@ def test_init_params():
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]
def test_empty_api_key_raises_error_bearer():
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "bearer", "api_key": ""},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
def test_empty_api_key_raises_error_basic():
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "basic", "api_key": ""},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
def test_empty_api_key_raises_error_custom():
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
def test_whitespace_only_api_key_raises_error():
"""Test that whitespace-only API key raises AuthorizationConfigError."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "bearer", "api_key": " "},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
def test_valid_api_key_works():
"""Test that valid API key works correctly for bearer auth."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "bearer", "api_key": "valid-api-key-123"},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
executor = Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
# Should not raise an error
headers = executor._assembling_headers()
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer valid-api-key-123"

View File

@ -1,3 +1,4 @@
import json
import time
import pytest
@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables):
def test_json_object_valid_schema():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
)
variables = [
VariableEntity(
@ -65,7 +68,7 @@ def test_json_object_valid_schema():
)
]
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
node = make_start_node(user_inputs, variables)
result = node._run()
@ -74,12 +77,23 @@ def test_json_object_valid_schema():
def test_json_object_invalid_json_string():
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
json_schema=schema,
)
]
@ -88,38 +102,21 @@ def test_json_object_invalid_json_string():
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
def test_json_object_valid_json_but_not_object(value):
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
)
]
user_inputs = {"profile": value}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
node._run()
def test_json_object_does_not_match_schema():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema():
]
# age is a string, which violates the schema (expects number)
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
node = make_start_node(user_inputs, variables)
@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema():
def test_json_object_missing_required_schema_field():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field():
]
# Missing required field "name"
user_inputs = {"profile": {"age": 20}}
user_inputs = {"profile": json.dumps({"age": 20})}
node = make_start_node(user_inputs, variables)
@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided():
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=False,
required=True,
)
]
@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided():
node = make_start_node(user_inputs, variables)
# With the variable now marked required, a missing input raises the required-field error
with pytest.raises(ValueError, match="profile must be a JSON object"):
with pytest.raises(ValueError, match="profile is required in input form"):
node._run()
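# A sketch of the validation these updated tests imply for JSON_OBJECT inputs,
# which now arrive as JSON strings (hypothetical helper; jsonschema is an
# assumed dependency, and the real logic lives in the start node):
import json
import jsonschema  # assumed available; validates instances against a JSON Schema

def _validate_json_object(name: str, raw: str, schema: str | None = None) -> dict:
    try:
        value = json.loads(raw)
    except (TypeError, ValueError):
        raise ValueError(f"{raw} must be a valid JSON object")
    if not isinstance(value, dict):
        raise ValueError(f"{name} must be a JSON object")
    if schema is not None:
        jsonschema.validate(instance=value, schema=json.loads(schema))
    return value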

View File

@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases:
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
class TestBillingServiceSubscriptionOperations:
"""Unit tests for subscription operations in BillingService.
Tests cover:
- Bulk plan retrieval with chunking
- Expired subscription cleanup whitelist retrieval
"""
@pytest.fixture
def mock_send_request(self):
"""Mock _send_request method."""
with patch.object(BillingService, "_send_request") as mock:
yield mock
def test_get_plan_bulk_with_empty_list(self, mock_send_request):
"""Test bulk plan retrieval with empty tenant list."""
# Arrange
tenant_ids = []
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert result == {}
mock_send_request.assert_not_called()
def test_get_plan_bulk_with_chunking(self, mock_send_request):
"""Test bulk plan retrieval with more than 200 tenants (chunking logic)."""
# Arrange - 250 tenants to test chunking (chunk_size = 200)
tenant_ids = [f"tenant-{i}" for i in range(250)]
# First chunk: tenants 0-199
first_chunk_response = {
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
}
# Second chunk: tenants 200-249
second_chunk_response = {
"data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)}
}
mock_send_request.side_effect = [first_chunk_response, second_chunk_response]
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert len(result) == 250
assert result["tenant-0"]["plan"] == "sandbox"
assert result["tenant-199"]["plan"] == "sandbox"
assert result["tenant-200"]["plan"] == "professional"
assert result["tenant-249"]["plan"] == "professional"
assert mock_send_request.call_count == 2
# Verify first chunk call
first_call = mock_send_request.call_args_list[0]
assert first_call[0][0] == "POST"
assert first_call[0][1] == "/subscription/plan/batch"
assert len(first_call[1]["json"]["tenant_ids"]) == 200
# Verify second chunk call
second_call = mock_send_request.call_args_list[1]
assert len(second_call[1]["json"]["tenant_ids"]) == 50
def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request):
"""Test bulk plan retrieval when one batch fails but others succeed."""
# Arrange - 250 tenants, second batch will fail
tenant_ids = [f"tenant-{i}" for i in range(250)]
# First chunk succeeds
first_chunk_response = {
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
}
# Second chunk fails: use a side-effect function that raises on the second call
def side_effect_func(*args, **kwargs):
if mock_send_request.call_count == 1:
return first_chunk_response
else:
raise ValueError("API error")
mock_send_request.side_effect = side_effect_func
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert - should only have data from first batch
assert len(result) == 200
assert result["tenant-0"]["plan"] == "sandbox"
assert result["tenant-199"]["plan"] == "sandbox"
assert "tenant-200" not in result
assert mock_send_request.call_count == 2
def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request):
"""Test bulk plan retrieval when all batches fail."""
# Arrange
tenant_ids = [f"tenant-{i}" for i in range(250)]
# All chunks fail
def side_effect_func(*args, **kwargs):
raise ValueError("API error")
mock_send_request.side_effect = side_effect_func
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert - should return empty dict
assert result == {}
assert mock_send_request.call_count == 2
def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request):
"""Test bulk plan retrieval with exactly 200 tenants (boundary condition)."""
# Arrange
tenant_ids = [f"tenant-{i}" for i in range(200)]
mock_send_request.return_value = {
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
}
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert len(result) == 200
assert mock_send_request.call_count == 1
def test_get_plan_bulk_with_empty_data_response(self, mock_send_request):
"""Test bulk plan retrieval with empty data in response."""
# Arrange
tenant_ids = ["tenant-1", "tenant-2"]
mock_send_request.return_value = {"data": {}}
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert result == {}
def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request):
"""Test successful retrieval of expired subscription cleanup whitelist."""
# Arrange
api_response = [
{
"created_at": "2025-10-16T01:56:17",
"tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6",
"contact": "example@dify.ai",
"id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5",
"expired_at": "2026-01-01T01:56:17",
"updated_at": "2025-10-16T01:56:17",
},
{
"created_at": "2025-10-16T02:00:00",
"tenant_id": "tenant-2",
"contact": "test@example.com",
"id": "whitelist-id-2",
"expired_at": "2026-02-01T00:00:00",
"updated_at": "2025-10-16T02:00:00",
},
{
"created_at": "2025-10-16T03:00:00",
"tenant_id": "tenant-3",
"contact": "another@example.com",
"id": "whitelist-id-3",
"expired_at": "2026-03-01T00:00:00",
"updated_at": "2025-10-16T03:00:00",
},
]
mock_send_request.return_value = {"data": api_response}
# Act
result = BillingService.get_expired_subscription_cleanup_whitelist()
# Assert - should return only tenant_ids
assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"]
assert len(result) == 3
assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6"
assert result[1] == "tenant-2"
assert result[2] == "tenant-3"
mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist")
def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request):
"""Test retrieval of empty cleanup whitelist."""
# Arrange
mock_send_request.return_value = {"data": []}
# Act
result = BillingService.get_expired_subscription_cleanup_whitelist()
# Assert
assert result == []
assert len(result) == 0
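# The two whitelist tests above imply a single GET whose payload is reduced to
# its tenant_id fields, roughly (hypothetical helper; the real method is
# BillingService.get_expired_subscription_cleanup_whitelist):
def _cleanup_whitelist_sketch(send_request):
    response = send_request("GET", "/subscription/cleanup/whitelist")
    return [entry["tenant_id"] for entry in response.get("data", [])]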
class TestBillingServiceIntegrationScenarios:
"""Integration-style tests simulating real-world usage scenarios.

View File

@ -0,0 +1,520 @@
"""
Unit tests for document indexing sync task.
This module tests the document indexing sync task functionality including:
- Syncing Notion documents when updated
- Validating document and data source existence
- Credential validation and retrieval
- Cleaning old segments before re-indexing
- Error handling and edge cases
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def tenant_id():
"""Generate a unique tenant ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def dataset_id():
"""Generate a unique dataset ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def document_id():
"""Generate a unique document ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def notion_workspace_id():
"""Generate a Notion workspace ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def notion_page_id():
"""Generate a Notion page ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def credential_id():
"""Generate a credential ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
"""Create a mock Dataset object."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.indexing_technique = "high_quality"
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@pytest.fixture
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
"""Create a mock Document object with Notion data source."""
doc = Mock(spec=Document)
doc.id = document_id
doc.dataset_id = dataset_id
doc.tenant_id = tenant_id
doc.data_source_type = "notion_import"
doc.indexing_status = "completed"
doc.error = None
doc.stopped_at = None
doc.processing_started_at = None
doc.doc_form = "text_model"
doc.data_source_info_dict = {
"notion_workspace_id": notion_workspace_id,
"notion_page_id": notion_page_id,
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
"credential_id": credential_id,
}
return doc
@pytest.fixture
def mock_document_segments(document_id):
"""Create mock DocumentSegment objects."""
segments = []
for i in range(3):
segment = Mock(spec=DocumentSegment)
segment.id = str(uuid.uuid4())
segment.document_id = document_id
segment.index_node_id = f"node-{document_id}-{i}"
segments.append(segment)
return segments
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
@pytest.fixture
def mock_datasource_provider_service():
"""Mock DatasourceProviderService."""
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
mock_service = MagicMock()
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
mock_service_class.return_value = mock_service
yield mock_service
@pytest.fixture
def mock_notion_extractor():
"""Mock NotionExtractor."""
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
mock_extractor = MagicMock()
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time
mock_extractor_class.return_value = mock_extractor
yield mock_extractor
@pytest.fixture
def mock_index_processor_factory():
"""Mock IndexProcessorFactory."""
with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
mock_processor = MagicMock()
mock_processor.clean = Mock()
mock_factory.return_value.init_index_processor.return_value = mock_processor
yield mock_factory
@pytest.fixture
def mock_indexing_runner():
"""Mock IndexingRunner."""
with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
mock_runner = MagicMock(spec=IndexingRunner)
mock_runner.run = Mock()
mock_runner_class.return_value = mock_runner
yield mock_runner
# ============================================================================
# Tests for document_indexing_sync_task
# ============================================================================
class TestDocumentIndexingSyncTask:
"""Tests for the document_indexing_sync_task function."""
def test_document_not_found(self, mock_db_session, dataset_id, document_id):
"""Test that task handles document not found gracefully."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = None
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_db_session.close.assert_called_once()
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_workspace_id is missing."""
# Arrange
mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_page_id is missing."""
# Arrange
mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when data_source_info is empty."""
# Arrange
mock_document.data_source_info_dict = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_credential_not_found(
self,
mock_db_session,
mock_datasource_provider_service,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles missing credentials by updating document status."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_datasource_provider_service.get_datasource_credentials.return_value = None
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
assert mock_document.indexing_status == "error"
assert "Datasource credential not found" in mock_document.error
assert mock_document.stopped_at is not None
mock_db_session.commit.assert_called()
mock_db_session.close.assert_called()
def test_page_not_updated(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
):
"""Test that task does nothing when page has not been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Return same time as stored in document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# No session operations should be performed beyond the initial query
mock_db_session.close.assert_not_called()
def test_successful_sync_when_page_updated(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test successful sync flow when Notion page has been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# NotionExtractor returns updated time
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Verify document status was updated to parsing
assert mock_document.indexing_status == "parsing"
assert mock_document.processing_started_at is not None
# Verify segments were cleaned
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])
# Verify session operations
assert mock_db_session.commit.called
mock_db_session.close.assert_called_once()
def test_dataset_not_found_during_cleaning(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles dataset not found during cleaning phase."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Document should still be set to parsing
assert mock_document.indexing_status == "parsing"
# Session should be closed after error
mock_db_session.close.assert_called_once()
def test_cleaning_error_continues_to_indexing(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
dataset_id,
document_id,
):
"""Test that indexing continues even if cleaning fails."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Indexing should still be attempted despite cleaning error
mock_indexing_runner.run.assert_called_once_with([mock_document])
mock_db_session.close.assert_called_once()
def test_indexing_runner_document_paused_error(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that DocumentIsPausedError is handled gracefully."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Session should be closed after handling error
mock_db_session.close.assert_called_once()
def test_indexing_runner_general_error(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that general exceptions during indexing are handled."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = Exception("Indexing error")
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Session should be closed after error
mock_db_session.close.assert_called_once()
def test_notion_extractor_initialized_with_correct_params(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
notion_workspace_id,
notion_page_id,
):
"""Test that NotionExtractor is initialized with correct parameters."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update
# Act
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
mock_extractor = MagicMock()
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
mock_extractor_class.return_value = mock_extractor
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_extractor_class.assert_called_once_with(
notion_workspace_id=notion_workspace_id,
notion_obj_id=notion_page_id,
notion_page_type="page",
notion_access_token="test_token",
tenant_id=mock_document.tenant_id,
)
def test_datasource_credentials_requested_correctly(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
credential_id,
):
"""Test that datasource credentials are requested with correct parameters."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
tenant_id=mock_document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
def test_credential_id_missing_uses_none(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles missing credential_id by passing None."""
# Arrange
mock_document.data_source_info_dict = {
"notion_workspace_id": "ws123",
"notion_page_id": "page123",
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
tenant_id=mock_document.tenant_id,
credential_id=None,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
def test_index_processor_clean_called_with_correct_params(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that index processor clean is called with correct parameters."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
mock_processor.clean.assert_called_once_with(
mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
)

View File

@ -37,6 +37,7 @@ show_help() {
echo " pipeline - Standard pipeline tasks"
echo " triggered_workflow_dispatcher - Trigger dispatcher tasks"
echo " trigger_refresh_executor - Trigger refresh tasks"
echo " retention - Retention tasks"
}
# Parse command line arguments
@ -105,10 +106,10 @@ if [[ -z "${QUEUES}" ]]; then
# Configure queues based on edition
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
else
# Community edition (SELF_HOSTED): dataset and workflow have separate queues
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
fi
echo "No queues specified, using edition-based defaults: ${QUEUES}"

View File

@ -1369,7 +1369,10 @@ PLUGIN_STDIO_BUFFER_SIZE=1024
PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880
PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120
# Plugin Daemon side timeout (configure to match the API side below)
PLUGIN_MAX_EXECUTION_TIMEOUT=600
# API side timeout (configure to match the Plugin Daemon side above)
PLUGIN_DAEMON_TIMEOUT=600.0
# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple
PIP_MIRROR_URL=
@ -1479,4 +1482,9 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# The API key of amplitude
AMPLITUDE_API_KEY=
AMPLITUDE_API_KEY=
# Sandbox expired-records cleanup configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@ -34,6 +34,7 @@ services:
PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
init_permissions:

View File

@ -591,6 +591,7 @@ x-shared-env: &shared-api-worker-env
PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880}
PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local}
PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage}
@ -663,6 +664,9 @@ x-shared-env: &shared-api-worker-env
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20}
ANNOTATION_IMPORT_MAX_CONCURRENT: ${ANNOTATION_IMPORT_MAX_CONCURRENT:-5}
AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21}
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000}
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30}
services:
# Init container to fix permissions
@ -699,6 +703,7 @@ services:
PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
init_permissions:

71 web/__mocks__/ky.ts Normal file
View File

@ -0,0 +1,71 @@
/**
* Mock for ky HTTP client
* This mock is used to avoid ESM issues in Jest tests
*/
type KyResponse = {
ok: boolean
status: number
statusText: string
headers: Headers
json: jest.Mock
text: jest.Mock
blob: jest.Mock
arrayBuffer: jest.Mock
clone: jest.Mock
}
type KyInstance = jest.Mock & {
get: jest.Mock
post: jest.Mock
put: jest.Mock
patch: jest.Mock
delete: jest.Mock
head: jest.Mock
create: jest.Mock
extend: jest.Mock
stop: symbol
}
const createResponse = (data: unknown = {}, status = 200): KyResponse => {
const response: KyResponse = {
ok: status >= 200 && status < 300,
status,
statusText: status === 200 ? 'OK' : 'Error',
headers: new Headers(),
json: jest.fn().mockResolvedValue(data),
text: jest.fn().mockResolvedValue(JSON.stringify(data)),
blob: jest.fn().mockResolvedValue(new Blob()),
arrayBuffer: jest.fn().mockResolvedValue(new ArrayBuffer(0)),
clone: jest.fn(),
}
// Ensure clone returns a new response-like object, not the same instance
response.clone.mockImplementation(() => createResponse(data, status))
return response
}
const createKyInstance = (): KyInstance => {
const instance = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) as KyInstance
// HTTP methods
instance.get = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.post = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.put = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.patch = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.delete = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.head = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
// Create new instance with custom options
instance.create = jest.fn().mockImplementation(() => createKyInstance())
instance.extend = jest.fn().mockImplementation(() => createKyInstance())
// Stop method for AbortController
instance.stop = Symbol('stop')
return instance
}
const ky = createKyInstance()
export default ky
export { ky }
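// A hedged usage sketch: with this file in place, tests can assert directly on
// the jest.fn() methods (assumes Jest resolves `ky` to this mock, as it does
// for a root-level __mocks__ folder shadowing a node module):
//
//   import ky from 'ky'
//
//   it('returns the stubbed 200 response', async () => {
//     const res = await ky.get('https://example.test')
//     expect(res.status).toBe(200)
//     expect(ky.get).toHaveBeenCalledWith('https://example.test')
//   })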

View File

@ -16,7 +16,7 @@ import {
import { useTranslation } from 'react-i18next'
import { useShallow } from 'zustand/react/shallow'
import s from './style.module.css'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useStore } from '@/app/components/app/store'
import AppSideBar from '@/app/components/app-sidebar'
import type { NavIcon } from '@/app/components/app-sidebar/navLink'

View File

@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react'
import type { Dayjs } from 'dayjs'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { formatToLocalTime } from '@/utils/format'
import { useI18N } from '@/context/i18n'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'

View File

@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select'
import type { Item } from '@/app/components/base/select'
import dayjs from 'dayjs'
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useTranslation } from 'react-i18next'
const today = dayjs()

View File

@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react'
import type { PopupProps } from './config-popup'
import ConfigPopup from './config-popup'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import {
PortalToFollowElem,
PortalToFollowElemContent,

View File

@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
import Tooltip from '@/app/components/base/tooltip'
import Divider from '@/app/components/base/divider'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
const I18N_PREFIX = 'app.tracing'

View File

@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Input from '@/app/components/base/input'
type Props = {

View File

@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS
import { TracingProvider } from './type'
import TracingIcon from './tracing-icon'
import ConfigButton from './config-button'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
import Indicator from '@/app/components/header/indicator'
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'

View File

@ -6,7 +6,7 @@ import {
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { TracingProvider } from './type'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'

View File

@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing'
type Props = {

View File

@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use
import useDocumentTitle from '@/hooks/use-document-title'
import ExtraInfo from '@/app/components/datasets/extra-info'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
export type IAppDetailLayoutProps = {
children: React.ReactNode

View File

@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
export default function SignInLayout({ children }: any) {

View File

@ -2,7 +2,7 @@
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useRouter, useSearchParams } from 'next/navigation'
import cn from 'classnames'
import { cn } from '@/utils/classnames'
import { RiCheckboxCircleFill } from '@remixicon/react'
import { useCountDown } from 'ahooks'
import Button from '@/app/components/base/button'

View File

@ -1,6 +1,6 @@
'use client'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import type { PropsWithChildren } from 'react'

View File

@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading'
import MailAndCodeAuth from './components/mail-and-code-auth'
import MailAndPasswordAuth from './components/mail-and-password-auth'
import SSOAuth from './components/sso-auth'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { LicenseStatus } from '@/types/feature'
import { IS_CE_EDITION } from '@/config'
import { useGlobalPublicStore } from '@/context/global-public-context'

View File

@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import { AppContextProvider } from '@/context/app-context'

View File

@ -1,13 +1,13 @@
'use client'
import { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import useSWR from 'swr'
import { useRouter, useSearchParams } from 'next/navigation'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Button from '@/app/components/base/button'
import { invitationCheck } from '@/service/common'
import Loading from '@/app/components/base/loading'
import useDocumentTitle from '@/hooks/use-document-title'
import { useInvitationCheck } from '@/service/use-common'
const ActivateForm = () => {
useDocumentTitle('')
@ -26,19 +26,21 @@ const ActivateForm = () => {
token,
},
}
const { data: checkRes } = useSWR(checkParams, invitationCheck, {
revalidateOnFocus: false,
onSuccess(data) {
if (data.is_valid) {
const params = new URLSearchParams(searchParams)
const { email, workspace_id } = data.data
params.set('email', encodeURIComponent(email))
params.set('workspace_id', encodeURIComponent(workspace_id))
params.set('invite_token', encodeURIComponent(token as string))
router.replace(`/signin?${params.toString()}`)
}
},
})
const { data: checkRes } = useInvitationCheck({
...checkParams.params,
token: token || undefined,
}, true)
useEffect(() => {
if (checkRes?.is_valid) {
const params = new URLSearchParams(searchParams)
const { email, workspace_id } = checkRes.data
params.set('email', encodeURIComponent(email))
params.set('workspace_id', encodeURIComponent(workspace_id))
params.set('invite_token', encodeURIComponent(token as string))
router.replace(`/signin?${params.toString()}`)
}
}, [checkRes, router, searchParams, token])
return (
<div className={

View File

@ -2,7 +2,7 @@
import React from 'react'
import Header from '../signin/_header'
import ActivateForm from './activateForm'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
const Activate = () => {

View File

@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie
import type { Operation } from './app-operations'
import AppOperations from './app-operations'
import dynamic from 'next/dynamic'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AppModeEnum } from '@/types/app'
const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {

View File

@ -16,7 +16,7 @@ import AppInfo from './app-info'
import NavLink from './navLink'
import { useStore as useAppStore } from '@/app/components/app/store'
import type { NavIcon } from './navLink'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AppModeEnum } from '@/types/app'
type Props = {

View File

@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem'
import ActionButton from '../../base/action-button'
import { RiMoreFill } from '@remixicon/react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Menu from './menu'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'

View File

@ -0,0 +1,379 @@
import React from 'react'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DatasetInfo from './index'
import Dropdown from './dropdown'
import Menu from './menu'
import MenuItem from './menu-item'
import type { DataSet } from '@/models/datasets'
import {
ChunkingMode,
DataSourceType,
DatasetPermission,
} from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import { RiEditLine } from '@remixicon/react'
let mockDataset: DataSet
let mockIsDatasetOperator = false
const mockReplace = jest.fn()
const mockInvalidDatasetList = jest.fn()
const mockInvalidDatasetDetail = jest.fn()
const mockExportPipeline = jest.fn()
const mockCheckIsUsedInApp = jest.fn()
const mockDeleteDataset = jest.fn()
const createDataset = (overrides: Partial<DataSet> = {}): DataSet => ({
id: 'dataset-1',
name: 'Dataset Name',
indexing_status: 'completed',
icon_info: {
icon: '📙',
icon_background: '#FFF4ED',
icon_type: 'emoji',
icon_url: '',
},
description: 'Dataset description',
permission: DatasetPermission.onlyMe,
data_source_type: DataSourceType.FILE,
indexing_technique: 'high_quality' as DataSet['indexing_technique'],
created_by: 'user-1',
updated_by: 'user-1',
updated_at: 1690000000,
app_count: 0,
doc_form: ChunkingMode.text,
document_count: 1,
total_document_count: 1,
word_count: 1000,
provider: 'internal',
embedding_model: 'text-embedding-3',
embedding_model_provider: 'openai',
embedding_available: true,
retrieval_model_dict: {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 5,
score_threshold_enabled: false,
score_threshold: 0,
},
retrieval_model: {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 5,
score_threshold_enabled: false,
score_threshold: 0,
},
tags: [],
external_knowledge_info: {
external_knowledge_id: '',
external_knowledge_api_id: '',
external_knowledge_api_name: '',
external_knowledge_api_endpoint: '',
},
external_retrieval_model: {
top_k: 0,
score_threshold: 0,
score_threshold_enabled: false,
},
built_in_field_enabled: false,
runtime_mode: 'rag_pipeline',
enable_api: false,
is_multimodal: false,
...overrides,
})
jest.mock('next/navigation', () => ({
useRouter: () => ({
replace: mockReplace,
}),
}))
jest.mock('@/context/dataset-detail', () => ({
useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }),
}))
jest.mock('@/context/app-context', () => ({
useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) =>
selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }),
}))
jest.mock('@/service/knowledge/use-dataset', () => ({
datasetDetailQueryKeyPrefix: ['dataset', 'detail'],
useInvalidDatasetList: () => mockInvalidDatasetList,
}))
jest.mock('@/service/use-base', () => ({
useInvalid: () => mockInvalidDatasetDetail,
}))
jest.mock('@/service/use-pipeline', () => ({
useExportPipelineDSL: () => ({
mutateAsync: mockExportPipeline,
}),
}))
jest.mock('@/service/datasets', () => ({
checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args),
deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args),
}))
jest.mock('@/hooks/use-knowledge', () => ({
useKnowledge: () => ({
formatIndexingTechniqueAndMethod: () => 'indexing-technique',
}),
}))
jest.mock('@/app/components/datasets/rename-modal', () => ({
__esModule: true,
default: ({
show,
onClose,
onSuccess,
}: {
show: boolean
onClose: () => void
onSuccess?: () => void
}) => {
if (!show)
return null
return (
<div data-testid="rename-modal">
<button type="button" onClick={onSuccess}>Success</button>
<button type="button" onClick={onClose}>Close</button>
</div>
)
},
}))
const openMenu = async (user: ReturnType<typeof userEvent.setup>) => {
const trigger = screen.getByRole('button')
await user.click(trigger)
}
describe('DatasetInfo', () => {
beforeEach(() => {
jest.clearAllMocks()
mockDataset = createDataset()
mockIsDatasetOperator = false
})
// Rendering of dataset summary details based on expand and dataset state.
describe('Rendering', () => {
it('should show dataset details when expanded', () => {
// Arrange
mockDataset = createDataset({ is_published: true })
render(<DatasetInfo expand />)
// Assert
expect(screen.getByText('Dataset Name')).toBeInTheDocument()
expect(screen.getByText('Dataset description')).toBeInTheDocument()
expect(screen.getByText('dataset.chunkingMode.general')).toBeInTheDocument()
expect(screen.getByText('indexing-technique')).toBeInTheDocument()
})
it('should show external tag when provider is external', () => {
// Arrange
mockDataset = createDataset({ provider: 'external', is_published: false })
render(<DatasetInfo expand />)
// Assert
expect(screen.getByText('dataset.externalTag')).toBeInTheDocument()
expect(screen.queryByText('dataset.chunkingMode.general')).not.toBeInTheDocument()
})
it('should hide detailed fields when collapsed', () => {
// Arrange
render(<DatasetInfo expand={false} />)
// Assert
expect(screen.queryByText('Dataset Name')).not.toBeInTheDocument()
expect(screen.queryByText('Dataset description')).not.toBeInTheDocument()
})
})
})
describe('MenuItem', () => {
beforeEach(() => {
jest.clearAllMocks()
})
// Event handling for menu item interactions.
describe('Interactions', () => {
it('should call handler when clicked', async () => {
const user = userEvent.setup()
const handleClick = jest.fn()
// Arrange
render(<MenuItem name="Edit" Icon={RiEditLine} handleClick={handleClick} />)
// Act
await user.click(screen.getByText('Edit'))
// Assert
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
})
describe('Menu', () => {
beforeEach(() => {
jest.clearAllMocks()
mockDataset = createDataset()
})
// Rendering of menu options based on runtime mode and delete visibility.
describe('Rendering', () => {
it('should show edit, export, and delete options when rag pipeline and deletable', () => {
// Arrange
mockDataset = createDataset({ runtime_mode: 'rag_pipeline' })
render(
<Menu
showDelete
openRenameModal={jest.fn()}
handleExportPipeline={jest.fn()}
detectIsUsedByApp={jest.fn()}
/>,
)
// Assert
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
expect(screen.getByText('datasetPipeline.operations.exportPipeline')).toBeInTheDocument()
expect(screen.getByText('common.operation.delete')).toBeInTheDocument()
})
it('should hide export and delete options when not rag pipeline and not deletable', () => {
// Arrange
mockDataset = createDataset({ runtime_mode: 'general' })
render(
<Menu
showDelete={false}
openRenameModal={jest.fn()}
handleExportPipeline={jest.fn()}
detectIsUsedByApp={jest.fn()}
/>,
)
// Assert
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
expect(screen.queryByText('datasetPipeline.operations.exportPipeline')).not.toBeInTheDocument()
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
})
})
})
describe('Dropdown', () => {
beforeEach(() => {
jest.clearAllMocks()
mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' })
mockIsDatasetOperator = false
mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' })
mockCheckIsUsedInApp.mockResolvedValue({ is_using: false })
mockDeleteDataset.mockResolvedValue({})
if (!('createObjectURL' in URL)) {
Object.defineProperty(URL, 'createObjectURL', {
value: jest.fn(),
writable: true,
})
}
if (!('revokeObjectURL' in URL)) {
Object.defineProperty(URL, 'revokeObjectURL', {
value: jest.fn(),
writable: true,
})
}
})
// Rendering behavior based on workspace role.
describe('Rendering', () => {
it('should hide delete option when user is dataset operator', async () => {
const user = userEvent.setup()
// Arrange
mockIsDatasetOperator = true
render(<Dropdown expand />)
// Act
await openMenu(user)
// Assert
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
})
})
// User interactions that trigger modals and exports.
describe('Interactions', () => {
it('should open rename modal when edit is clicked', async () => {
const user = userEvent.setup()
// Arrange
render(<Dropdown expand />)
// Act
await openMenu(user)
await user.click(screen.getByText('common.operation.edit'))
// Assert
expect(screen.getByTestId('rename-modal')).toBeInTheDocument()
})
it('should export pipeline when export is clicked', async () => {
const user = userEvent.setup()
const anchorClickSpy = jest.spyOn(HTMLAnchorElement.prototype, 'click')
const createObjectURLSpy = jest.spyOn(URL, 'createObjectURL')
// Arrange
render(<Dropdown expand />)
// Act
await openMenu(user)
await user.click(screen.getByText('datasetPipeline.operations.exportPipeline'))
// Assert
await waitFor(() => {
expect(mockExportPipeline).toHaveBeenCalledWith({
pipelineId: 'pipeline-1',
include: false,
})
})
expect(createObjectURLSpy).toHaveBeenCalledTimes(1)
expect(anchorClickSpy).toHaveBeenCalledTimes(1)
})
it('should show delete confirmation when delete is clicked', async () => {
const user = userEvent.setup()
// Arrange
render(<Dropdown expand />)
// Act
await openMenu(user)
await user.click(screen.getByText('common.operation.delete'))
// Assert
await waitFor(() => {
expect(screen.getByText('dataset.deleteDatasetConfirmContent')).toBeInTheDocument()
})
})
it('should delete dataset and redirect when confirm is clicked', async () => {
const user = userEvent.setup()
// Arrange
render(<Dropdown expand />)
// Act
await openMenu(user)
await user.click(screen.getByText('common.operation.delete'))
await user.click(await screen.findByRole('button', { name: 'common.operation.confirm' }))
// Assert
await waitFor(() => {
expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1')
})
expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1)
expect(mockReplace).toHaveBeenCalledWith('/datasets')
})
})
})
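// The export assertions above (createObjectURL called once, anchor clicked
// once) imply a download helper along these lines (a sketch; the component's
// actual code may differ):
//
//   const blob = new Blob([dsl], { type: 'text/yaml' })
//   const url = URL.createObjectURL(blob)
//   const link = document.createElement('a')
//   link.href = url
//   link.download = fileName  // hypothetical name
//   link.click()
//   URL.revokeObjectURL(url)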

View File

@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import type { DataSet } from '@/models/datasets'
import { DOC_FORM_TEXT } from '@/models/datasets'
import { useKnowledge } from '@/hooks/use-knowledge'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Dropdown from './dropdown'
type DatasetInfoProps = {

View File

@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon'
import Divider from '../base/divider'
import NavLink from './navLink'
import type { NavIcon } from './navLink'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import Effect from '../base/effect'
import Dropdown from './dataset-info/dropdown'

View File

@ -9,7 +9,7 @@ import AppSidebarDropdown from './app-sidebar-dropdown'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useStore as useAppStore } from '@/app/components/app/store'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Divider from '../base/divider'
import { useHover, useKeyPress } from 'ahooks'
import ToggleButton from './toggle-button'

View File

@ -2,7 +2,7 @@
import React from 'react'
import { useSelectedLayoutSegment } from 'next/navigation'
import Link from 'next/link'
import classNames from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import type { RemixiconComponentType } from '@remixicon/react'
export type NavIcon = React.ComponentType<
@ -42,7 +42,7 @@ const NavLink = ({
const NavIcon = isActive ? iconMap.selected : iconMap.normal
const renderIcon = () => (
<div className={classNames(mode !== 'expand' && '-ml-1')}>
<div className={cn(mode !== 'expand' && '-ml-1')}>
<NavIcon className="h-4 w-4 shrink-0" aria-hidden="true" />
</div>
)
@ -53,21 +53,17 @@ const NavLink = ({
key={name}
type='button'
disabled
className={classNames(
'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
'pl-3 pr-1',
)}
className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
'pl-3 pr-1')}
title={mode === 'collapse' ? name : ''}
aria-disabled
>
{renderIcon()}
<span
className={classNames(
'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
mode === 'expand'
? 'ml-2 max-w-none opacity-100'
: 'ml-0 max-w-0 opacity-0',
)}
: 'ml-0 max-w-0 opacity-0')}
>
{name}
</span>
@ -79,22 +75,18 @@ const NavLink = ({
<Link
key={name}
href={href}
className={classNames(
isActive
? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
: 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
'flex h-8 items-center rounded-lg pl-3 pr-1',
)}
className={cn(isActive
? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
: 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
'flex h-8 items-center rounded-lg pl-3 pr-1')}
title={mode === 'collapse' ? name : ''}
>
{renderIcon()}
<span
className={classNames(
'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
mode === 'expand'
? 'ml-2 max-w-none opacity-100'
: 'ml-0 max-w-0 opacity-0',
)}
: 'ml-0 max-w-0 opacity-0')}
>
{name}
</span>

View File

@ -1,7 +1,7 @@
import React from 'react'
import Button from '../base/button'
import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Tooltip from '../base/tooltip'
import { useTranslation } from 'react-i18next'
import { getKeyboardKeyNameBySystem } from '../workflow/utils'

View File

@ -2,12 +2,6 @@ import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import EditItem, { EditItemType } from './index'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('AddAnnotationModal/EditItem', () => {
test('should render query inputs with user avatar and placeholder strings', () => {
render(

View File

@ -0,0 +1,42 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import BatchAction from './batch-action'
describe('BatchAction', () => {
const baseProps = {
selectedIds: ['1', '2', '3'],
onBatchDelete: jest.fn(),
onCancel: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
it('should show the selected count and trigger cancel action', () => {
render(<BatchAction {...baseProps} className='custom-class' />)
expect(screen.getByText('3')).toBeInTheDocument()
expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(baseProps.onCancel).toHaveBeenCalledTimes(1)
})
it('should confirm before running batch delete', async () => {
const onBatchDelete = jest.fn().mockResolvedValue(undefined)
render(<BatchAction {...baseProps} onBatchDelete={onBatchDelete} />)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.delete' }))
await screen.findByText('appAnnotation.list.delete.title')
await act(async () => {
fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.delete' })[1])
})
await waitFor(() => {
expect(onBatchDelete).toHaveBeenCalledTimes(1)
})
})
})

View File

@ -3,7 +3,7 @@ import { RiDeleteBinLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import Divider from '@/app/components/base/divider'
import classNames from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Confirm from '@/app/components/base/confirm'
const i18nPrefix = 'appAnnotation.batchAction'
@ -38,7 +38,7 @@ const BatchAction: FC<IBatchActionProps> = ({
setIsNotDeleting()
}
return (
<div className={classNames('pointer-events-none flex w-full justify-center', className)}>
<div className={cn('pointer-events-none flex w-full justify-center', className)}>
<div className='pointer-events-auto flex items-center gap-x-1 rounded-[10px] border border-components-actionbar-border-accent bg-components-actionbar-bg-accent p-1 shadow-xl shadow-shadow-shadow-5 backdrop-blur-[5px]'>
<div className='inline-flex items-center gap-x-2 py-1 pl-2 pr-3'>
<span className='flex h-5 w-5 items-center justify-center rounded-md bg-text-accent px-1 py-0.5 text-xs font-medium text-text-primary-on-surface'>

View File

@ -0,0 +1,72 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import CSVDownload from './csv-downloader'
import I18nContext from '@/context/i18n'
import { LanguagesSupported } from '@/i18n-config/language'
import type { Locale } from '@/i18n-config'
const downloaderProps: any[] = []
jest.mock('react-papaparse', () => ({
useCSVDownloader: jest.fn(() => ({
CSVDownloader: ({ children, ...props }: any) => {
downloaderProps.push(props)
return <div data-testid="mock-csv-downloader">{children}</div>
},
Type: { Link: 'link' },
})),
}))
const renderWithLocale = (locale: Locale) => {
return render(
<I18nContext.Provider value={{
locale,
i18n: {},
setLocaleOnClient: jest.fn().mockResolvedValue(undefined),
}}
>
<CSVDownload />
</I18nContext.Provider>,
)
}
describe('CSVDownload', () => {
const englishTemplate = [
['question', 'answer'],
['question1', 'answer1'],
['question2', 'answer2'],
]
const chineseTemplate = [
['问题', '答案'],
['问题 1', '答案 1'],
['问题 2', '答案 2'],
]
beforeEach(() => {
downloaderProps.length = 0
})
it('should render the structure preview and pass English template data by default', () => {
renderWithLocale('en-US' as Locale)
expect(screen.getByText('share.generation.csvStructureTitle')).toBeInTheDocument()
expect(screen.getByText('appAnnotation.batchModal.template')).toBeInTheDocument()
expect(downloaderProps[0]).toMatchObject({
filename: 'template-en-US',
type: 'link',
bom: true,
data: englishTemplate,
})
})
it('should switch to the Chinese template when locale matches the secondary language', () => {
const locale = LanguagesSupported[1] as Locale
renderWithLocale(locale)
expect(downloaderProps[0]).toMatchObject({
filename: `template-${locale}`,
data: chineseTemplate,
})
})
})
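Note: pushing the mocked child's props into a module-level array (reset in `beforeEach`) is a simple capture pattern. The same can be done with a `jest.fn()`, as long as its name starts with `mock` so Jest's hoisting rules allow referencing it inside the factory; a sketch:

```tsx
// Sketch: capturing props via a jest.fn() instead of a hand-rolled array.
const mockDownloaderProps = jest.fn()

jest.mock('react-papaparse', () => ({
  useCSVDownloader: () => ({
    CSVDownloader: ({ children, ...props }: any) => {
      mockDownloaderProps(props) // record received props for later assertions
      return <div data-testid="mock-csv-downloader">{children}</div>
    },
    Type: { Link: 'link' },
  }),
}))

// later: expect(mockDownloaderProps).toHaveBeenCalledWith(expect.objectContaining({ bom: true }))
```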

View File

@ -4,7 +4,7 @@ import React, { useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import { RiDeleteBinLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files'
import { ToastContext } from '@/app/components/base/toast'
import Button from '@/app/components/base/button'

View File

@ -0,0 +1,164 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import BatchModal, { ProcessStatus } from './index'
import { useProviderContext } from '@/context/provider-context'
import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation'
import type { IBatchModalProps } from './index'
import Toast from '@/app/components/base/toast'
jest.mock('@/app/components/base/toast', () => ({
__esModule: true,
default: {
notify: jest.fn(),
},
}))
jest.mock('@/service/annotation', () => ({
annotationBatchImport: jest.fn(),
checkAnnotationBatchImportProgress: jest.fn(),
}))
jest.mock('@/context/provider-context', () => ({
useProviderContext: jest.fn(),
}))
jest.mock('./csv-downloader', () => ({
__esModule: true,
default: () => <div data-testid="csv-downloader-stub" />,
}))
let lastUploadedFile: File | undefined
jest.mock('./csv-uploader', () => ({
__esModule: true,
default: ({ file, updateFile }: { file?: File; updateFile: (file?: File) => void }) => (
<div>
<button
data-testid="mock-uploader"
onClick={() => {
lastUploadedFile = new File(['question,answer'], 'batch.csv', { type: 'text/csv' })
updateFile(lastUploadedFile)
}}
>
upload
</button>
{file && <span data-testid="selected-file">{file.name}</span>}
</div>
),
}))
jest.mock('@/app/components/billing/annotation-full', () => ({
__esModule: true,
default: () => <div data-testid="annotation-full" />,
}))
const mockNotify = Toast.notify as jest.Mock
const useProviderContextMock = useProviderContext as jest.Mock
const annotationBatchImportMock = annotationBatchImport as jest.Mock
const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as jest.Mock
const renderComponent = (props: Partial<IBatchModalProps> = {}) => {
const mergedProps: IBatchModalProps = {
appId: 'app-id',
isShow: true,
onCancel: jest.fn(),
onAdded: jest.fn(),
...props,
}
return {
...render(<BatchModal {...mergedProps} />),
props: mergedProps,
}
}
describe('BatchModal', () => {
beforeEach(() => {
jest.clearAllMocks()
lastUploadedFile = undefined
useProviderContextMock.mockReturnValue({
plan: {
usage: { annotatedResponse: 0 },
total: { annotatedResponse: 10 },
},
enableBilling: false,
})
})
it('should disable run action and show billing hint when annotation quota is full', () => {
useProviderContextMock.mockReturnValue({
plan: {
usage: { annotatedResponse: 10 },
total: { annotatedResponse: 10 },
},
enableBilling: true,
})
renderComponent()
expect(screen.getByTestId('annotation-full')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })).toBeDisabled()
})
it('should reset uploader state when modal closes and allow manual cancellation', () => {
const { rerender, props } = renderComponent()
fireEvent.click(screen.getByTestId('mock-uploader'))
expect(screen.getByTestId('selected-file')).toHaveTextContent('batch.csv')
rerender(<BatchModal {...props} isShow={false} />)
rerender(<BatchModal {...props} isShow />)
expect(screen.queryByTestId('selected-file')).toBeNull()
fireEvent.click(screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }))
expect(props.onCancel).toHaveBeenCalledTimes(1)
})
it('should submit the csv file, poll status, and notify when import completes', async () => {
jest.useFakeTimers()
const { props } = renderComponent()
const fileTrigger = screen.getByTestId('mock-uploader')
fireEvent.click(fileTrigger)
const runButton = screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })
expect(runButton).not.toBeDisabled()
annotationBatchImportMock.mockResolvedValue({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING })
checkAnnotationBatchImportProgressMock
.mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING })
.mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.COMPLETED })
await act(async () => {
fireEvent.click(runButton)
})
await waitFor(() => {
expect(annotationBatchImportMock).toHaveBeenCalledTimes(1)
})
const formData = annotationBatchImportMock.mock.calls[0][0].body as FormData
expect(formData.get('file')).toBe(lastUploadedFile)
await waitFor(() => {
expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(1)
})
await act(async () => {
jest.runOnlyPendingTimers()
})
await waitFor(() => {
expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(2)
})
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith({
type: 'success',
message: 'appAnnotation.batchModal.completed',
})
expect(props.onAdded).toHaveBeenCalledTimes(1)
expect(props.onCancel).toHaveBeenCalledTimes(1)
})
jest.useRealTimers()
})
})
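Note: the polling spec is a useful template — resolve the kickoff request, queue one in-flight and one terminal progress response, then advance only the pending timer inside `act` so exactly one extra poll fires. One caution for similar specs: restore real timers in `afterEach` rather than at the end of the test body, so a failed assertion cannot leak fake timers into later specs. Sketch:

```ts
// Sketch: guarantees timer cleanup even when an assertion inside the test throws.
afterEach(() => {
  jest.useRealTimers()
})
```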

View File

@ -245,7 +245,7 @@ describe('EditItem', () => {
expect(mockSave).toHaveBeenCalledWith('Test save content')
})
it('should show delete option when content changes', async () => {
it('should show delete option and restore original content when delete is clicked', async () => {
// Arrange
const mockSave = jest.fn().mockResolvedValue(undefined)
const props = {
@ -267,7 +267,13 @@ describe('EditItem', () => {
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert
expect(mockSave).toHaveBeenCalledWith('Modified content')
expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content')
expect(await screen.findByText('common.operation.delete')).toBeInTheDocument()
await user.click(screen.getByText('common.operation.delete'))
expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content')
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
})
it('should handle keyboard interactions in edit mode', async () => {
@ -393,5 +399,68 @@ describe('EditItem', () => {
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
expect(screen.getByText('Test content')).toBeInTheDocument()
})
it('should handle save failure gracefully in edit mode', async () => {
// Arrange
const mockSave = jest.fn().mockRejectedValueOnce(new Error('Save failed'))
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
// Enter edit mode and save (should fail)
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.type(textarea, 'New content')
// Save should fail but not throw
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert - Should remain in edit mode when save fails
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
expect(mockSave).toHaveBeenCalledWith('New content')
})
it('should handle delete action failure gracefully', async () => {
// Arrange
const mockSave = jest.fn()
.mockResolvedValueOnce(undefined) // First save succeeds
.mockRejectedValueOnce(new Error('Delete failed')) // Delete fails
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
// Edit content to show delete button
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Modified content')
// Save to create new content
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
await screen.findByText('common.operation.delete')
// Click delete (should fail but not throw)
await user.click(screen.getByText('common.operation.delete'))
// Assert - Delete action should handle error gracefully
expect(mockSave).toHaveBeenCalledTimes(2)
expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content')
expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content')
// When delete fails, the delete button should still be visible (state not changed)
expect(screen.getByText('common.operation.delete')).toBeInTheDocument()
expect(screen.getByText('Modified content')).toBeInTheDocument()
})
})
})
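Note: both new error-path specs pass only because the component now catches the rejected `onSave` promise (see the source change below). `mockRejectedValueOnce` is deliberately used instead of `mockRejectedValue`: only the first call rejects, so a retry within the same spec can be driven to succeed, e.g.:

```ts
// Sketch: first attempt fails, a retry in the same spec succeeds.
const onSave = jest.fn()
  .mockRejectedValueOnce(new Error('Save failed'))
  .mockResolvedValueOnce(undefined)
```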

View File

@ -6,7 +6,7 @@ import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react'
import { Robot, User } from '@/app/components/base/icons/src/public/avatar'
import Textarea from '@/app/components/base/textarea'
import Button from '@/app/components/base/button'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
export enum EditItemType {
Query = 'query',
@ -52,8 +52,14 @@ const EditItem: FC<Props> = ({
}, [content])
const handleSave = async () => {
await onSave(newContent)
setIsEdit(false)
try {
await onSave(newContent)
setIsEdit(false)
}
catch {
// Keep edit mode open when save fails
// Error notification is handled by the parent component
}
}
const handleCancel = () => {
@ -96,9 +102,16 @@ const EditItem: FC<Props> = ({
<div className='mr-2'>·</div>
<div
className='flex cursor-pointer items-center space-x-1'
onClick={() => {
setNewContent(content)
onSave(content)
onClick={async () => {
try {
await onSave(content)
// Only update UI state after successful delete
setNewContent(content)
}
catch {
// Delete action failed - error is already handled by parent
// UI state remains unchanged, user can retry
}
}}
>
<div className='h-3.5 w-3.5'>
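Note: the new catch blocks swallow the error and rely on the caller to surface it; that parent is not shown in this diff. A hedged sketch of what such an `onSave` might look like — the service call, toast message key, and names are illustrative, not from this commit:

```tsx
// Hypothetical parent handler: notify on failure, then re-throw so
// EditItem's catch keeps the editor open for a retry.
const handleSaveAnnotation = async (content: string) => {
  try {
    await updateAnnotation(appId, annotationId, content) // illustrative service call
  }
  catch (e) {
    Toast.notify({ type: 'error', message: 'common.api.actionFailed' }) // key is an assumption
    throw e
  }
}
```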

Some files were not shown because too many files have changed in this diff