diff --git a/.github/scripts/generate-i18n-changes.mjs b/.github/scripts/generate-i18n-changes.mjs new file mode 100644 index 0000000000..3d25115ac3 --- /dev/null +++ b/.github/scripts/generate-i18n-changes.mjs @@ -0,0 +1,82 @@ +import { execFileSync } from 'node:child_process' +import fs from 'node:fs' +import path from 'node:path' + +const repoRoot = process.cwd() +const baseSha = process.env.BASE_SHA || '' +const headSha = process.env.HEAD_SHA || '' +const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean) +const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json' + +const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`) + +const readCurrentJson = (fileStem) => { + const filePath = englishPath(fileStem) + if (!fs.existsSync(filePath)) + return null + + return JSON.parse(fs.readFileSync(filePath, 'utf8')) +} + +const readBaseJson = (fileStem) => { + if (!baseSha) + return null + + try { + const relativePath = `web/i18n/en-US/${fileStem}.json` + const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' }) + return JSON.parse(content) + } + catch { + return null + } +} + +const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue) + +const changes = {} + +for (const fileStem of files) { + const currentJson = readCurrentJson(fileStem) + const beforeJson = readBaseJson(fileStem) || {} + const afterJson = currentJson || {} + const added = {} + const updated = {} + const deleted = [] + + for (const [key, value] of Object.entries(afterJson)) { + if (!(key in beforeJson)) { + added[key] = value + continue + } + + if (!compareJson(beforeJson[key], value)) { + updated[key] = { + before: beforeJson[key], + after: value, + } + } + } + + for (const key of Object.keys(beforeJson)) { + if (!(key in afterJson)) + deleted.push(key) + } + + changes[fileStem] = { + fileDeleted: currentJson === null, + added, + updated, + deleted, + } +} + +fs.writeFileSync( + outputPath, + JSON.stringify({ + baseSha, + headSha, + files, + changes, + }) +) diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index a813c87cec..e001f4d677 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -68,89 +68,7 @@ jobs: " web/i18n-config/languages.ts | sed 's/[[:space:]]*$//') generate_changes_json() { - node <<'NODE' - const { execFileSync } = require('node:child_process') - const fs = require('node:fs') - const path = require('node:path') - - const repoRoot = process.cwd() - const baseSha = process.env.BASE_SHA || '' - const headSha = process.env.HEAD_SHA || '' - const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean) - - const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`) - - const readCurrentJson = (fileStem) => { - const filePath = englishPath(fileStem) - if (!fs.existsSync(filePath)) - return null - - return JSON.parse(fs.readFileSync(filePath, 'utf8')) - } - - const readBaseJson = (fileStem) => { - if (!baseSha) - return null - - try { - const relativePath = `web/i18n/en-US/${fileStem}.json` - const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' }) - return JSON.parse(content) - } - catch (error) { - return null - } - } - - const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue) - - const changes = {} - - for (const fileStem of files) { - const currentJson = readCurrentJson(fileStem) - const beforeJson = readBaseJson(fileStem) || {} - const afterJson = currentJson || {} - const added = {} - const updated = {} - const deleted = [] - - for (const [key, value] of Object.entries(afterJson)) { - if (!(key in beforeJson)) { - added[key] = value - continue - } - - if (!compareJson(beforeJson[key], value)) { - updated[key] = { - before: beforeJson[key], - after: value, - } - } - } - - for (const key of Object.keys(beforeJson)) { - if (!(key in afterJson)) - deleted.push(key) - } - - changes[fileStem] = { - fileDeleted: currentJson === null, - added, - updated, - deleted, - } - } - - fs.writeFileSync( - '/tmp/i18n-changes.json', - JSON.stringify({ - baseSha, - headSha, - files, - changes, - }) - ) - NODE + node .github/scripts/generate-i18n-changes.mjs } if [ "${{ github.event_name }}" = "repository_dispatch" ]; then @@ -270,7 +188,7 @@ jobs: Tool rules: - Use Read for repository files. - Use Edit for JSON updates. - - Use Bash only for `pnpm`. + - Use Bash only for `vp`. - Do not use Bash for `git`, `gh`, or branch management. Required execution plan: @@ -292,7 +210,7 @@ jobs: - Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate. - If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth. 4. Run a scoped pre-check before editing: - - `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` + - `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` - Use this command as the source of truth for missing and extra keys inside the current scope. 5. Apply translations. - For every target language and scoped file: @@ -300,19 +218,19 @@ jobs: - If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed. - ADD missing keys. - UPDATE stale translations when the English value changed. - - DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope. + - DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope. - Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names. - Match the existing terminology and register used by each locale. - Prefer one Edit per file when stable, but prioritize correctness over batching. 6. Verify only the edited files. - - Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- ` - - Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` + - Run `vp run dify-web#lint:fix --quiet -- ` + - Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` - If verification fails, fix the remaining problems before continuing. 7. Stop after the scoped locale files are updated and verification passes. - Do not create branches, commits, or pull requests. claude_args: | --max-turns 120 - --allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep" + --allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep" - name: Prepare branch metadata id: pr_meta @@ -354,6 +272,7 @@ jobs: - name: Create or update translation PR if: steps.pr_meta.outputs.has_changes == 'true' env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }} FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }} TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }} @@ -402,8 +321,8 @@ jobs: '', '## Verification', '', - `- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``, - `- \`pnpm --dir web lint:fix --quiet -- \``, + `- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``, + `- \`vp run dify-web#lint:fix --quiet -- \``, '', '## Notes', '', diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml index a1ca42b26e..9a11d3e8df 100644 --- a/.github/workflows/trigger-i18n-sync.yml +++ b/.github/workflows/trigger-i18n-sync.yml @@ -42,88 +42,7 @@ jobs: fi export BASE_SHA HEAD_SHA CHANGED_FILES - node <<'NODE' - const { execFileSync } = require('node:child_process') - const fs = require('node:fs') - const path = require('node:path') - - const repoRoot = process.cwd() - const baseSha = process.env.BASE_SHA || '' - const headSha = process.env.HEAD_SHA || '' - const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean) - - const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`) - - const readCurrentJson = (fileStem) => { - const filePath = englishPath(fileStem) - if (!fs.existsSync(filePath)) - return null - - return JSON.parse(fs.readFileSync(filePath, 'utf8')) - } - - const readBaseJson = (fileStem) => { - if (!baseSha) - return null - - try { - const relativePath = `web/i18n/en-US/${fileStem}.json` - const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' }) - return JSON.parse(content) - } - catch (error) { - return null - } - } - - const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue) - - const changes = {} - - for (const fileStem of files) { - const beforeJson = readBaseJson(fileStem) || {} - const afterJson = readCurrentJson(fileStem) || {} - const added = {} - const updated = {} - const deleted = [] - - for (const [key, value] of Object.entries(afterJson)) { - if (!(key in beforeJson)) { - added[key] = value - continue - } - - if (!compareJson(beforeJson[key], value)) { - updated[key] = { - before: beforeJson[key], - after: value, - } - } - } - - for (const key of Object.keys(beforeJson)) { - if (!(key in afterJson)) - deleted.push(key) - } - - changes[fileStem] = { - fileDeleted: readCurrentJson(fileStem) === null, - added, - updated, - deleted, - } - } - - fs.writeFileSync( - '/tmp/i18n-changes.json', - JSON.stringify({ - baseSha, - headSha, - files, - changes, - }) - ) - NODE + node .github/scripts/generate-i18n-changes.mjs if [ -n "$CHANGED_FILES" ]; then echo "has_changes=true" >> "$GITHUB_OUTPUT" diff --git a/.vite-hooks/pre-commit b/.vite-hooks/pre-commit index db5c504606..13bbd81cf6 100755 --- a/.vite-hooks/pre-commit +++ b/.vite-hooks/pre-commit @@ -81,8 +81,8 @@ if $web_modified; then if $web_ts_modified; then echo "Running TypeScript type-check:tsgo" - if ! pnpm run type-check:tsgo; then - echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors." + if ! npm run type-check:tsgo; then + echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors." exit 1 fi else @@ -90,8 +90,8 @@ if $web_modified; then fi echo "Running knip" - if ! pnpm run knip; then - echo "Knip check failed. Please run 'pnpm run knip' to fix the errors." + if ! npm run knip; then + echo "Knip check failed. Please run 'npm run knip' to fix the errors." exit 1 fi diff --git a/api/.env.example b/api/.env.example index 411456d2f3..485bc6fb3a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -74,6 +74,13 @@ REDIS_USE_CLUSTERS=false REDIS_CLUSTERS= REDIS_CLUSTERS_PASSWORD= +REDIS_RETRY_RETRIES=3 +REDIS_RETRY_BACKOFF_BASE=1.0 +REDIS_RETRY_BACKOFF_CAP=10.0 +REDIS_SOCKET_TIMEOUT=5.0 +REDIS_SOCKET_CONNECT_TIMEOUT=5.0 +REDIS_HEALTH_CHECK_INTERVAL=30 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BACKEND=redis diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 3b91207545..b49275758a 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -117,6 +117,37 @@ class RedisConfig(BaseSettings): default=None, ) + REDIS_RETRY_RETRIES: NonNegativeInt = Field( + description="Maximum number of retries per Redis command on " + "transient failures (ConnectionError, TimeoutError, socket.timeout)", + default=3, + ) + + REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field( + description="Base delay in seconds for exponential backoff between retries", + default=1.0, + ) + + REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field( + description="Maximum backoff delay in seconds between retries", + default=10.0, + ) + + REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field( + description="Socket timeout in seconds for Redis read/write operations", + default=5.0, + ) + + REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field( + description="Socket timeout in seconds for Redis connection establishment", + default=5.0, + ) + + REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field( + description="Interval in seconds between Redis connection health checks (0 to disable)", + default=30, + ) + @field_validator("REDIS_MAX_CONNECTIONS", mode="before") @classmethod def _empty_string_to_none_for_max_conns(cls, v): diff --git a/api/controllers/common/controller_schemas.py b/api/controllers/common/controller_schemas.py index e13bf025fc..39e3b5857d 100644 --- a/api/controllers/common/controller_schemas.py +++ b/api/controllers/common/controller_schemas.py @@ -48,11 +48,27 @@ class SavedMessageCreatePayload(BaseModel): # --- Workflow schemas --- +class DefaultBlockConfigQuery(BaseModel): + q: str | None = None + + +class WorkflowListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=10, ge=1, le=100) + user_id: str | None = None + named_only: bool = False + + class WorkflowRunPayload(BaseModel): inputs: dict[str, Any] files: list[dict[str, Any]] | None = None +class WorkflowUpdatePayload(BaseModel): + marked_name: str | None = Field(default=None, max_length=20) + marked_comment: str | None = Field(default=None, max_length=100) + + # --- Audio schemas --- diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 16e1fa3245..06192936f1 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -92,11 +92,13 @@ class AppImportApi(Resource): EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") # Return appropriate status code based on result status = result.status - if status == ImportStatus.FAILED: - return result.model_dump(mode="json"), 400 - elif status == ImportStatus.PENDING: - return result.model_dump(mode="json"), 202 - return result.model_dump(mode="json"), 200 + match status: + case ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + case ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS: + return result.model_dump(mode="json"), 200 @console_ns.route("/apps/imports//confirm") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index c4e163f68d..4638c8e90e 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services +from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload from controllers.console import console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.workflow_run import workflow_run_node_execution_model @@ -144,10 +145,6 @@ class PublishWorkflowPayload(BaseModel): marked_comment: str | None = Field(default=None, max_length=100) -class DefaultBlockConfigQuery(BaseModel): - q: str | None = None - - class ConvertToWorkflowPayload(BaseModel): name: str | None = None icon_type: str | None = None @@ -155,18 +152,6 @@ class ConvertToWorkflowPayload(BaseModel): icon_background: str | None = None -class WorkflowListQuery(BaseModel): - page: int = Field(default=1, ge=1, le=99999) - limit: int = Field(default=10, ge=1, le=100) - user_id: str | None = None - named_only: bool = False - - -class WorkflowUpdatePayload(BaseModel): - marked_name: str | None = Field(default=None, max_length=20) - marked_comment: str | None = Field(default=None, max_length=100) - - class WorkflowFeaturesPayload(BaseModel): features: dict[str, Any] = Field(..., description="Workflow feature configuration") diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 0621b880fe..bda2f9d5e0 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -403,24 +403,27 @@ class VariableApi(Resource): new_value = None if raw_value is not None: - if variable.value_type == SegmentType.FILE: - if not isinstance(raw_value, dict): - raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=app_model.tenant_id, - access_controller=_file_access_controller, - ) - elif variable.value_type == SegmentType.ARRAY_FILE: - if not isinstance(raw_value, list): - raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") - if len(raw_value) > 0 and not isinstance(raw_value[0], dict): - raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=app_model.tenant_id, - access_controller=_file_access_controller, - ) + match variable.value_type: + case SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) + case SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) + case _: + pass new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 93feec0019..3549f9542d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -223,24 +223,27 @@ class RagPipelineVariableApi(Resource): new_value = None if raw_value is not None: - if variable.value_type == SegmentType.FILE: - if not isinstance(raw_value, dict): - raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=pipeline.tenant_id, - access_controller=_file_access_controller, - ) - elif variable.value_type == SegmentType.ARRAY_FILE: - if not isinstance(raw_value, list): - raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") - if len(raw_value) > 0 and not isinstance(raw_value[0], dict): - raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=pipeline.tenant_id, - access_controller=_file_access_controller, - ) + match variable.value_type: + case SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) + case SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) + case _: + pass new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index 732a6dc446..76a8c136e4 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -83,11 +83,13 @@ class RagPipelineImportApi(Resource): # Return appropriate status code based on result status = result.status - if status == ImportStatus.FAILED: - return result.model_dump(mode="json"), 400 - elif status == ImportStatus.PENDING: - return result.model_dump(mode="json"), 202 - return result.model_dump(mode="json"), 200 + match status: + case ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + case ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS: + return result.model_dump(mode="json"), 200 @console_ns.route("/rag/pipelines/imports//confirm") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 70dfe47d7f..6c02646c22 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services +from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( @@ -94,22 +95,6 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload): original_document_id: str | None = None -class DefaultBlockConfigQuery(BaseModel): - q: str | None = None - - -class WorkflowListQuery(BaseModel): - page: int = Field(default=1, ge=1, le=99999) - limit: int = Field(default=10, ge=1, le=100) - user_id: str | None = None - named_only: bool = False - - -class WorkflowUpdatePayload(BaseModel): - marked_name: str | None = Field(default=None, max_length=20) - marked_comment: str | None = Field(default=None, max_length=100) - - class NodeIdQuery(BaseModel): node_id: str diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index e37e78c966..5d79e1b5e9 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -168,12 +168,13 @@ class ConsoleWorkflowEventsApi(Resource): else: msg_generator = MessageGenerator() generator: BaseAppGenerator - if app.mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app.mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") + match app.mode: + case AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + case AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + case _: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 6a2e0b65fb..66082893b8 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -138,12 +138,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() - if auth_type == WebAppAuthType.PUBLIC: - return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) - elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": - raise WebAppAuthRequiredError("Please login as external user.") - elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": - raise WebAppAuthRequiredError("Please login as internal user.") + match auth_type: + case WebAppAuthType.PUBLIC: + return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) + case WebAppAuthType.EXTERNAL: + if user_auth_type != "external": + raise WebAppAuthRequiredError("Please login as external user.") + case WebAppAuthType.INTERNAL: + if user_auth_type != "internal": + raise WebAppAuthRequiredError("Please login as internal user.") end_user = None if end_user_id: diff --git a/api/controllers/web/workflow_events.py b/api/controllers/web/workflow_events.py index 61568e70e6..474f9c0957 100644 --- a/api/controllers/web/workflow_events.py +++ b/api/controllers/web/workflow_events.py @@ -72,12 +72,13 @@ class WorkflowEventsApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) msg_generator = MessageGenerator() generator: BaseAppGenerator - if app_mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app_mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") + match app_mode: + case AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + case AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + case _: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a884a1c7f9..7b4cb98bd4 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variable_loader import VariableLoader from graphon.variables.variables import Variable from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): :return: List of conversation variables ready for use """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: existing_variables = self._load_existing_conversation_variables(session) if not existing_variables: @@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Convert to Variable objects for use in the workflow conversation_variables = [var.to_variable() for var in existing_variables] - session.commit() return cast(list[Variable], conversation_variables) def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 5203de225c..0ce9ddce9e 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.nodes import BuiltinNodeTypes from graphon.runtime import GraphRuntimeState from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): @contextmanager def _database_session(self): """Context manager for database sessions.""" - with Session(db.engine, expire_on_commit=False) as session: - try: - yield session - session.commit() - except Exception: - session.rollback() - raise + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + yield session def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 49af169e88..f1b8b08eaa 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -7,7 +7,7 @@ from typing import Union from graphon.entities import WorkflowStartReason from graphon.enums import WorkflowExecutionStatus from graphon.runtime import GraphRuntimeState -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager @@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): @contextmanager def _database_session(self): """Context manager for database sessions.""" - with Session(db.engine, expire_on_commit=False) as session: - try: - yield session - session.commit() - except Exception: - session.rollback() - raise + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + yield session def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 182f1b767d..0bb10190c4 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,6 +1,6 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from core.entities.model_entities import ModelStatus @@ -57,37 +57,37 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL used_quota = 1 if used_quota is not None and system_configuration.current_quota_type is not None: - if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: - from services.credit_pool_service import CreditPoolService + match system_configuration.current_quota_type: + case ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - ) - elif system_configuration.current_quota_type == ProviderQuotaType.PAID: - from services.credit_pool_service import CreditPoolService - - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - pool_type="paid", - ) - else: - with Session(db.engine) as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, ) - session.execute(stmt) - session.commit() + case ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + case ProviderQuotaType.FREE: + with sessionmaker(bind=db.engine).begin() as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. + Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=naive_utc_now(), + ) + ) + session.execute(stmt) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 9df78a7830..6bb177fe02 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): event = message.event if isinstance(event, QueueErrorEvent): - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: err = self.handle_error(event=event, session=session, message_id=self._message_id) - session.commit() yield self.error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): @@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): answer=output_moderation_answer ) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Save message self._save_message(session=session, trace_manager=trace_manager) - session.commit() message_end_resp = self._message_end_to_stream_response() yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index b23a33923b..77310baf74 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -40,41 +40,44 @@ def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, Upl size = 0 extension = "" - if message_file.transfer_method == FileTransferMethod.REMOTE_URL: - url = message_file.url - if message_file.url: - filename = message_file.url.split("/")[-1].split("?")[0] - if "." in filename: - extension = "." + filename.rsplit(".", 1)[1] - elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: - if upload_file: - url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) - filename = upload_file.name - mime_type = upload_file.mime_type or "application/octet-stream" - size = upload_file.size or 0 - extension = f".{upload_file.extension}" if upload_file.extension else "" - elif message_file.upload_file_id: - url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: - if message_file.url.startswith(("http://", "https://")): + match message_file.transfer_method: + case FileTransferMethod.REMOTE_URL: url = message_file.url - filename = message_file.url.split("/")[-1].split("?")[0] - if "." in filename: - extension = "." + filename.rsplit(".", 1)[1] - else: - url_parts = message_file.url.split("/") - if url_parts: - file_part = url_parts[-1].split("?")[0] - if "." in file_part: - tool_file_id, ext = file_part.rsplit(".", 1) - extension = f".{ext}" - if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH: + if message_file.url: + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + case FileTransferMethod.LOCAL_FILE: + if upload_file: + url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) + filename = upload_file.name + mime_type = upload_file.mime_type or "application/octet-stream" + size = upload_file.size or 0 + extension = f".{upload_file.extension}" if upload_file.extension else "" + elif message_file.upload_file_id: + url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) + case FileTransferMethod.TOOL_FILE if message_file.url: + if message_file.url.startswith(("http://", "https://")): + url = message_file.url + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + else: + url_parts = message_file.url.split("/") + if url_parts: + file_part = url_parts[-1].split("?")[0] + if "." in file_part: + tool_file_id, ext = file_part.rsplit(".", 1) + extension = f".{ext}" + if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH: + extension = ".bin" + else: + tool_file_id = file_part extension = ".bin" - else: - tool_file_id = file_part - extension = ".bin" - url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) - filename = file_part + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) + filename = file_part + case FileTransferMethod.TOOL_FILE | FileTransferMethod.DATASOURCE_FILE: + pass transfer_method_value = message_file.transfer_method.value remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 8de002ae55..72171d1536 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -187,15 +187,16 @@ def build_parameter_schema( def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict: """Prepare arguments based on app mode""" - if app.mode == AppMode.WORKFLOW: - return {"inputs": arguments} - elif app.mode == AppMode.COMPLETION: - return {"query": "", "inputs": arguments} - else: - # Chat modes - create a copy to avoid modifying original dict - args_copy = arguments.copy() - query = args_copy.pop("query", "") - return {"query": query, "inputs": args_copy} + match app.mode: + case AppMode.WORKFLOW: + return {"inputs": arguments} + case AppMode.COMPLETION: + return {"query": "", "inputs": arguments} + case _: + # Chat modes - create a copy to avoid modifying original dict + args_copy = arguments.copy() + query = args_copy.pop("query", "") + return {"query": query, "inputs": args_copy} def extract_answer_from_response(app: App, response: Any) -> str: @@ -229,17 +230,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str: def process_mapping_response(app: App, response: Mapping) -> str: """Process mapping response based on app mode""" - if app.mode in { - AppMode.ADVANCED_CHAT, - AppMode.COMPLETION, - AppMode.CHAT, - AppMode.AGENT_CHAT, - }: - return response.get("answer", "") - elif app.mode == AppMode.WORKFLOW: - return json.dumps(response["data"]["outputs"], ensure_ascii=False) - else: - raise ValueError("Invalid app mode: " + str(app.mode)) + match app.mode: + case AppMode.ADVANCED_CHAT | AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT: + return response.get("answer", "") + case AppMode.WORKFLOW: + return json.dumps(response["data"]["outputs"], ensure_ascii=False) + case _: + raise ValueError("Invalid app mode: " + str(app.mode)) def convert_input_form_to_parameters( diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index be11d2223c..e2d2be92cb 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -72,17 +72,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): conversation_id = conversation_id or "" - if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}: - if not query: - raise ValueError("missing query") + match app.mode: + case AppMode.ADVANCED_CHAT | AppMode.AGENT_CHAT | AppMode.CHAT: + if not query: + raise ValueError("missing query") - return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) - elif app.mode == AppMode.WORKFLOW: - return cls.invoke_workflow_app(app, user, stream, inputs, files) - elif app.mode == AppMode.COMPLETION: - return cls.invoke_completion_app(app, user, stream, inputs, files) - - raise ValueError("unexpected app type") + return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) + case AppMode.WORKFLOW: + return cls.invoke_workflow_app(app, user, stream, inputs, files) + case AppMode.COMPLETION: + return cls.invoke_completion_app(app, user, stream, inputs, files) + case _: + raise ValueError("unexpected app type") @classmethod def invoke_chat_app( @@ -98,60 +99,61 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ invoke chat app """ - if app.mode == AppMode.ADVANCED_CHAT: - workflow = app.workflow - if not workflow: + match app.mode: + case AppMode.ADVANCED_CHAT: + workflow = app.workflow + if not workflow: + raise ValueError("unexpected app type") + + pause_config = PauseStateLayerConfig( + session_factory=db.engine, + state_owner_user_id=workflow.created_by, + ) + + return AdvancedChatAppGenerator().generate( + app_model=app, + workflow=workflow, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + workflow_run_id=str(uuid.uuid4()), + streaming=stream, + pause_state_config=pause_config, + ) + case AppMode.AGENT_CHAT: + return AgentChatAppGenerator().generate( + app_model=app, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + streaming=stream, + ) + case AppMode.CHAT: + return ChatAppGenerator().generate( + app_model=app, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + streaming=stream, + ) + case _: raise ValueError("unexpected app type") - pause_config = PauseStateLayerConfig( - session_factory=db.engine, - state_owner_user_id=workflow.created_by, - ) - - return AdvancedChatAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - workflow_run_id=str(uuid.uuid4()), - streaming=stream, - pause_state_config=pause_config, - ) - elif app.mode == AppMode.AGENT_CHAT: - return AgentChatAppGenerator().generate( - app_model=app, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - streaming=stream, - ) - elif app.mode == AppMode.CHAT: - return ChatAppGenerator().generate( - app_model=app, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - streaming=stream, - ) - else: - raise ValueError("unexpected app type") - @classmethod def invoke_workflow_app( cls, diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index ec4858ae2e..c75c30a98a 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -209,7 +209,10 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/decode/from_identifier", PluginDecodeResponse, - params={"plugin_unique_identifier": plugin_unique_identifier}, + params={ + "plugin_unique_identifier": plugin_unique_identifier, + "PluginUniqueIdentifier": plugin_unique_identifier, # compat with daemon <= 0.5.4 + }, ) def fetch_plugin_installation_by_ids( diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 552de66f8b..e3b3f83c20 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -961,36 +961,37 @@ class ProviderManager: raise ValueError("quota_used is None") if provider_record.quota_limit is None: raise ValueError("quota_limit is None") - if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=trail_pool.quota_used, - quota_limit=trail_pool.quota_limit, - is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + match provider_quota.quota_type: + case ProviderQuotaType.TRIAL if trail_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=trail_pool.quota_used, + quota_limit=trail_pool.quota_limit, + is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=paid_pool.quota_used, - quota_limit=paid_pool.quota_limit, - is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + case ProviderQuotaType.PAID if paid_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=paid_pool.quota_used, + quota_limit=paid_pool.quota_limit, + is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - else: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=provider_record.quota_used, - quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used - or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + case _: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=provider_record.quota_used, + quota_limit=provider_record.quota_limit, + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) quota_configurations.append(quota_configuration) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index ddb549ba9d..79cc5f0344 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) - self.analyticdb_vector._create_collection_if_not_exists(dimension) + self.analyticdb_vector.create_collection_if_not_exists(dimension) self.analyticdb_vector.add_texts(texts, embeddings) - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: self.analyticdb_vector.add_texts(documents, embeddings) + return [] def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index fb6eaa370a..726ee8c050 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -123,7 +123,7 @@ class AnalyticdbVectorOpenAPI: else: raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") - def _create_collection_if_not_exists(self, embedding_dimension: int): + def create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 12126f32d6..41c33a3ab1 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -1,5 +1,6 @@ import json import uuid +from collections.abc import Iterator from contextlib import contextmanager from typing import Any @@ -74,7 +75,7 @@ class AnalyticdbVectorBySql: ) @contextmanager - def _get_cursor(self): + def _get_cursor(self) -> Iterator[Any]: assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() @@ -130,7 +131,7 @@ class AnalyticdbVectorBySql: ) cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}") - def _create_collection_if_not_exists(self, embedding_dimension: int): + def create_collection_if_not_exists(self, embedding_dimension: int): cache_key = f"vector_indexing_{self._collection_name}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 73787c2f00..5b0cfbea15 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -2,7 +2,7 @@ import json from typing import Any, TypedDict import chromadb -from chromadb import QueryResult, Settings +from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage] from pydantic import BaseModel from configs import dify_config @@ -106,14 +106,15 @@ class ChromaVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) document_ids_filter = kwargs.get("document_ids_filter") + results: QueryResult if document_ids_filter: - results: QueryResult = collection.query( + results = collection.query( query_embeddings=query_vector, n_results=kwargs.get("top_k", 4), where={"document_id": {"$in": document_ids_filter}}, # type: ignore ) else: - results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore + results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore score_threshold = float(kwargs.get("score_threshold") or 0.0) # Check if results contain data @@ -165,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory): config=ChromaConfig( host=dify_config.CHROMA_HOST or "", port=dify_config.CHROMA_PORT, - tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, - database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, + tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage] + database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage] auth_provider=dify_config.CHROMA_AUTH_PROVIDER, auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS, ), diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f4fcb975c3..b5ff87fc5d 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import qdrant_client from flask import current_app @@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset, DatasetCollectionBinding if TYPE_CHECKING: - from qdrant_client import grpc # noqa from qdrant_client.conversions import common_types from qdrant_client.http import models as rest @@ -180,7 +179,7 @@ class QdrantVector(BaseVector): for batch_ids, points in self._generate_rest_batches( texts, embeddings, filtered_metadatas, uuids, 64, self._group_id ): - self._client.upsert(collection_name=self._collection_name, points=points) + self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points)) added_ids.extend(batch_ids) return added_ids @@ -472,7 +471,7 @@ class QdrantVector(BaseVector): def _reload_if_needed(self): if isinstance(self._client, QdrantLocal): - self._client._load() + self._client._load() # pyright: ignore[reportPrivateUsage] @classmethod def _document_from_scored_point( diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index e486375ec2..3ecc9867fa 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) -Base = declarative_base() # type: Any +Base: Any = declarative_base() class RelytConfig(BaseModel): diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 7dd8beaa46..f9fbfbc409 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor): def extract(self) -> list[Document]: from unstructured.__version__ import __version__ as __unstructured_version__ - from unstructured.file_utils.filetype import FileType, detect_filetype + from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage] + FileType, + detect_filetype, + ) unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: - import magic # noqa: F401 + import magic # noqa: F401 # pyright: ignore[reportUnusedImport] is_doc = detect_filetype(self._file_path) == FileType.DOC except ImportError: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4e9b53b83e..0f3351fd68 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import ( DatasetEntity, @@ -884,7 +884,7 @@ class DatasetRetrieval: self._send_trace_task(message_id, documents, timer) return - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Collect all document_ids and batch fetch DatasetDocuments document_ids = { doc.metadata["document_id"] @@ -975,7 +975,6 @@ class DatasetRetrieval: {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False, ) - session.commit() self._send_trace_task(message_id, documents, timer) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d45d45c520..2593e381cf 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -205,16 +205,160 @@ class ToolManager: :return: the tool """ - if provider_type == ToolProviderType.BUILT_IN: - # check if the builtin tool need credentials - provider_controller = cls.get_builtin_provider(provider_id, tenant_id) + match provider_type: + case ToolProviderType.BUILT_IN: + provider_controller = cls.get_builtin_provider(provider_id, tenant_id) - builtin_tool = provider_controller.get_tool(tool_name) - if not builtin_tool: - raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") + builtin_tool = provider_controller.get_tool(tool_name) + if not builtin_tool: + raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") + + if not provider_controller.need_credentials: + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + builtin_provider = None + if isinstance(provider_controller, PluginToolProviderController): + provider_id_entity = ToolProviderID(provider_id) + if is_valid_uuid(credential_id): + try: + builtin_provider_stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + builtin_provider = db.session.scalar(builtin_provider_stmt) + except Exception as e: + builtin_provider = None + logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") + + if builtin_provider is None: + with Session(db.engine) as session: + builtin_provider = session.scalar( + sa.select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"no default provider for {provider_id}") + else: + builtin_provider = db.session.scalar( + select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id) + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .limit(1) + ) + + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=builtin_provider.id, + provider=provider_id, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + + encrypter, cache = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) + ], + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id + ), + ) + + decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) + + if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): + # TODO: circular import + from core.plugin.impl.oauth import OAuthHandler + from services.tools.builtin_tools_manage_service import BuiltinToolManageService + + tool_provider = ToolProviderID(provider_id) + provider_name = tool_provider.provider_name + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" + system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) + + oauth_handler = OAuthHandler() + refreshed_credentials = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=builtin_provider.user_id, + plugin_id=tool_provider.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + # update the credentials + builtin_provider.encrypted_credentials = json.dumps( + encrypter.encrypt(refreshed_credentials.credentials) + ) + builtin_provider.expires_at = refreshed_credentials.expires_at + db.session.commit() + decrypted_credentials = refreshed_credentials.credentials + cache.delete() - if not provider_controller.need_credentials: return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials=dict(decrypted_credentials), + credential_type=builtin_provider.credential_type, + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + + case ToolProviderType.API: + api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=api_provider, + ) + return api_provider.get_tool(tool_name).fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials=dict(encrypter.decrypt(credentials)), + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + case ToolProviderType.WORKFLOW: + workflow_provider_stmt = select(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id + ) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + workflow_provider = session.scalar(workflow_provider_stmt) + + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) + if controller_tools is None or len(controller_tools) == 0: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, user_id=user_id, @@ -223,177 +367,28 @@ class ToolManager: tool_invoke_from=tool_invoke_from, ) ) - builtin_provider = None - if isinstance(provider_controller, PluginToolProviderController): - provider_id_entity = ToolProviderID(provider_id) - # get specific credentials - if is_valid_uuid(credential_id): - try: - builtin_provider_stmt = select(BuiltinToolProvider).where( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.id == credential_id, - ) - builtin_provider = db.session.scalar(builtin_provider_stmt) - except Exception as e: - builtin_provider = None - logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) - # if the provider has been deleted, raise an error - if builtin_provider is None: - raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") - - # fallback to the default provider - if builtin_provider is None: - # use the default provider - with Session(db.engine) as session: - builtin_provider = session.scalar( - sa.select(BuiltinToolProvider) - .where( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), - ) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - ) - if builtin_provider is None: - raise ToolProviderNotFoundError(f"no default provider for {provider_id}") - else: - builtin_provider = db.session.scalar( - select(BuiltinToolProvider) - .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .limit(1) - ) - - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - - # check if the credential is allowed to be used - from core.helper.credential_utils import check_credential_policy_compliance - - check_credential_policy_compliance( - credential_id=builtin_provider.id, - provider=provider_id, - credential_type=PluginCredentialType.TOOL, - check_existence=False, - ) - - encrypter, cache = create_provider_encrypter( - tenant_id=tenant_id, - config=[ - x.to_basic_provider_config() - for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) - ], - cache=ToolProviderCredentialsCache( - tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id - ), - ) - - # decrypt the credentials - decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) - - # check if the credentials is expired - if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): - # TODO: circular import - from core.plugin.impl.oauth import OAuthHandler - from services.tools.builtin_tools_manage_service import BuiltinToolManageService - - # refresh the credentials - tool_provider = ToolProviderID(provider_id) - provider_name = tool_provider.provider_name - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" - system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) - - oauth_handler = OAuthHandler() - # refresh the credentials - refreshed_credentials = oauth_handler.refresh_credentials( - tenant_id=tenant_id, - user_id=builtin_provider.user_id, - plugin_id=tool_provider.plugin_id, - provider=provider_name, - redirect_uri=redirect_uri, - system_credentials=system_credentials or {}, - credentials=decrypted_credentials, - ) - # update the credentials - builtin_provider.encrypted_credentials = json.dumps( - encrypter.encrypt(refreshed_credentials.credentials) - ) - builtin_provider.expires_at = refreshed_credentials.expires_at - db.session.commit() - decrypted_credentials = refreshed_credentials.credentials - cache.delete() - - return builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials=dict(decrypted_credentials), - credential_type=builtin_provider.credential_type, - runtime_parameters={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - - elif provider_type == ToolProviderType.API: - api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=api_provider, - ) - return api_provider.get_tool(tool_name).fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials=dict(encrypter.decrypt(credentials)), - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider_stmt = select(WorkflowToolProvider).where( - WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id - ) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): - workflow_provider = session.scalar(workflow_provider_stmt) - - if workflow_provider is None: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) - controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) - if controller_tools is None or len(controller_tools) == 0: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - elif provider_type == ToolProviderType.APP: - raise NotImplementedError("app provider not implemented") - elif provider_type == ToolProviderType.PLUGIN: - plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) - runtime = getattr(plugin_tool, "runtime", None) - if runtime is not None: - runtime.user_id = user_id - runtime.invoke_from = invoke_from - runtime.tool_invoke_from = tool_invoke_from - return plugin_tool - elif provider_type == ToolProviderType.MCP: - mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) - runtime = getattr(mcp_tool, "runtime", None) - if runtime is not None: - runtime.user_id = user_id - runtime.invoke_from = invoke_from - runtime.tool_invoke_from = tool_invoke_from - return mcp_tool - else: - raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + case ToolProviderType.APP: + raise NotImplementedError("app provider not implemented") + case ToolProviderType.PLUGIN: + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool + case ToolProviderType.MCP: + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool + case ToolProviderType.DATASET_RETRIEVAL: + raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + case _: + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod def get_agent_tool_runtime( @@ -1027,31 +1022,31 @@ class ToolManager: :param provider_id: the id of the provider :return: """ - provider_type = provider_type - provider_id = provider_id - if provider_type == ToolProviderType.BUILT_IN: - provider = ToolManager.get_builtin_provider(provider_id, tenant_id) - if isinstance(provider, PluginToolProviderController): + match provider_type: + case ToolProviderType.BUILT_IN: + provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + if isinstance(provider, PluginToolProviderController): + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} + return cls.generate_builtin_tool_icon_url(provider_id) + case ToolProviderType.API: + return cls.generate_api_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.WORKFLOW: + return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.PLUGIN: + provider = ToolManager.get_plugin_provider(provider_id, tenant_id) try: return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} - return cls.generate_builtin_tool_icon_url(provider_id) - elif provider_type == ToolProviderType.API: - return cls.generate_api_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.WORKFLOW: - return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.PLUGIN: - provider = ToolManager.get_plugin_provider(provider_id, tenant_id) - try: - return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - raise ValueError(f"plugin provider {provider_id} not found") - elif provider_type == ToolProviderType.MCP: - return cls.generate_mcp_tool_icon_url(tenant_id, provider_id) - else: - raise ValueError(f"provider type {provider_type} not found") + case ToolProviderType.MCP: + return cls.generate_mcp_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL: + raise ValueError(f"provider type {provider_type} not found") + case _: + raise ValueError(f"provider type {provider_type} not found") @classmethod def _convert_tool_parameters_type( diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index a3fb4eda92..a17b7f108d 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -305,14 +305,15 @@ class WorkflowTool(Tool): "transfer_method": file.transfer_method.value, "type": file.type.value, } - if file.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = resolve_file_record_id(file.reference) - elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = resolve_file_record_id(file.reference) - elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE: - file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) - elif file.transfer_method == FileTransferMethod.REMOTE_URL: - file_dict["url"] = file.generate_url() + match file.transfer_method: + case FileTransferMethod.TOOL_FILE: + file_dict["tool_file_id"] = resolve_file_record_id(file.reference) + case FileTransferMethod.LOCAL_FILE: + file_dict["upload_file_id"] = resolve_file_record_id(file.reference) + case FileTransferMethod.DATASOURCE_FILE: + file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) + case FileTransferMethod.REMOTE_URL: + file_dict["url"] = file.generate_url() files.append(file_dict) except Exception: @@ -357,8 +358,11 @@ class WorkflowTool(Tool): def _update_file_mapping(self, file_dict: dict): file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) - if transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_id - elif transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_id + match transfer_method: + case FileTransferMethod.TOOL_FILE: + file_dict["tool_file_id"] = file_id + case FileTransferMethod.LOCAL_FILE: + file_dict["upload_file_id"] = file_id + case FileTransferMethod.REMOTE_URL | FileTransferMethod.DATASOURCE_FILE: + pass return file_dict diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 6ff162973c..f8e239d250 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -4,7 +4,7 @@ from graphon.entities.base_node_data import BaseNodeData from graphon.enums import NodeType from pydantic import BaseModel -from core.rag.entities import WeightedScoreConfig +from core.rag.entities.retrieval_settings import WeightedScoreConfig from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index ebaac93934..6a0d633627 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -155,24 +155,25 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = raw_data continue - if param_type == SegmentType.FILE: - # Get File object (already processed by webhook controller) - files = webhook_data.get("files", {}) - if files and isinstance(files, dict): - file = files.get(param_name) - if file and isinstance(file, dict): - file_var = self.generate_file_var(param_name, file) - if file_var: - outputs[param_name] = file_var + match param_type: + case SegmentType.FILE: + # Get File object (already processed by webhook controller) + files = webhook_data.get("files", {}) + if files and isinstance(files, dict): + file = files.get(param_name) + if file and isinstance(file, dict): + file_var = self.generate_file_var(param_name, file) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = files else: outputs[param_name] = files else: outputs[param_name] = files - else: - outputs[param_name] = files - else: - # Get regular body parameter - outputs[param_name] = webhook_data.get("body", {}).get(param_name) + case _: + # Get regular body parameter + outputs[param_name] = webhook_data.get("body", {}).get(param_name) # Include raw webhook data for debugging/advanced use outputs["_webhook_raw"] = webhook_data diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 5f528dbf9e..b9e592cadb 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, Any, Union import redis from redis import RedisError +from redis.backoff import ExponentialWithJitterBackoff # type: ignore from redis.cache import CacheConfig from redis.client import PubSub from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection +from redis.retry import Retry from redis.sentinel import Sentinel from configs import dify_config @@ -158,8 +160,41 @@ def _get_cache_configuration() -> CacheConfig | None: return CacheConfig() +def _get_retry_policy() -> Retry: + """Build the shared retry policy for Redis connections.""" + return Retry( + backoff=ExponentialWithJitterBackoff( + base=dify_config.REDIS_RETRY_BACKOFF_BASE, + cap=dify_config.REDIS_RETRY_BACKOFF_CAP, + ), + retries=dify_config.REDIS_RETRY_RETRIES, + ) + + +def _get_connection_health_params() -> dict[str, Any]: + """Get connection health and retry parameters for standalone and Sentinel Redis clients.""" + return { + "retry": _get_retry_policy(), + "socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT, + "socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, + "health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL, + } + + +def _get_cluster_connection_health_params() -> dict[str, Any]: + """Get retry and timeout parameters for Redis Cluster clients. + + RedisCluster does not support ``health_check_interval`` as a constructor + keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded + here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` + are passed through. + """ + params = _get_connection_health_params() + return {k: v for k, v in params.items() if k != "health_check_interval"} + + def _get_base_redis_params() -> dict[str, Any]: - """Get base Redis connection parameters.""" + """Get base Redis connection parameters including retry and health policy.""" return { "username": dify_config.REDIS_USERNAME, "password": dify_config.REDIS_PASSWORD or None, @@ -169,6 +204,7 @@ def _get_base_redis_params() -> dict[str, Any]: "decode_responses": False, "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, "cache_config": _get_cache_configuration(), + **_get_connection_health_params(), } @@ -215,6 +251,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: "password": dify_config.REDIS_CLUSTERS_PASSWORD, "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, "cache_config": _get_cache_configuration(), + **_get_cluster_connection_health_params(), } if dify_config.REDIS_MAX_CONNECTIONS: cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS @@ -226,7 +263,8 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis """Create standalone Redis client.""" connection_class, ssl_kwargs = _get_ssl_configuration() - redis_params.update( + params = {**redis_params} + params.update( { "host": dify_config.REDIS_HOST, "port": dify_config.REDIS_PORT, @@ -235,28 +273,31 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis ) if dify_config.REDIS_MAX_CONNECTIONS: - redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS if ssl_kwargs: - redis_params.update(ssl_kwargs) + params.update(ssl_kwargs) - pool = redis.ConnectionPool(**redis_params) + pool = redis.ConnectionPool(**params) client: redis.Redis = redis.Redis(connection_pool=pool) return client def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster: max_conns = dify_config.REDIS_MAX_CONNECTIONS - if use_clusters: - if max_conns: - return RedisCluster.from_url(pubsub_url, max_connections=max_conns) - else: - return RedisCluster.from_url(pubsub_url) + if use_clusters: + health_params = _get_cluster_connection_health_params() + kwargs: dict[str, Any] = {**health_params} + if max_conns: + kwargs["max_connections"] = max_conns + return RedisCluster.from_url(pubsub_url, **kwargs) + + health_params = _get_connection_health_params() + kwargs = {**health_params} if max_conns: - return redis.Redis.from_url(pubsub_url, max_connections=max_conns) - else: - return redis.Redis.from_url(pubsub_url) + kwargs["max_connections"] = max_conns + return redis.Redis.from_url(pubsub_url, **kwargs) def init_app(app: DifyApp): diff --git a/api/models/model.py b/api/models/model.py index 43ddf344d2..12865c4d22 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1632,52 +1632,53 @@ class Message(Base): files: list[File] = [] for message_file in message_files: - if message_file.transfer_method == FileTransferMethod.LOCAL_FILE: - if message_file.upload_file_id is None: - raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") - file = file_factory.build_from_mapping( - mapping={ + match message_file.transfer_method: + case FileTransferMethod.LOCAL_FILE: + if message_file.upload_file_id is None: + raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "upload_file_id": message_file.upload_file_id, + }, + tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), + ) + case FileTransferMethod.REMOTE_URL: + if message_file.url is None: + raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "upload_file_id": message_file.upload_file_id, + "url": message_file.url, + }, + tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), + ) + case FileTransferMethod.TOOL_FILE: + if message_file.upload_file_id is None: + assert message_file.url is not None + message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] + mapping = { "id": message_file.id, "type": message_file.type, "transfer_method": message_file.transfer_method, - "upload_file_id": message_file.upload_file_id, - }, - tenant_id=current_app.tenant_id, - access_controller=_get_file_access_controller(), - ) - elif message_file.transfer_method == FileTransferMethod.REMOTE_URL: - if message_file.url is None: - raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") - file = file_factory.build_from_mapping( - mapping={ - "id": message_file.id, - "type": message_file.type, - "transfer_method": message_file.transfer_method, - "upload_file_id": message_file.upload_file_id, - "url": message_file.url, - }, - tenant_id=current_app.tenant_id, - access_controller=_get_file_access_controller(), - ) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: - if message_file.upload_file_id is None: - assert message_file.url is not None - message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] - mapping = { - "id": message_file.id, - "type": message_file.type, - "transfer_method": message_file.transfer_method, - "tool_file_id": message_file.upload_file_id, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=current_app.tenant_id, - access_controller=_get_file_access_controller(), - ) - else: - raise ValueError( - f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" - ) + "tool_file_id": message_file.upload_file_id, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), + ) + case FileTransferMethod.DATASOURCE_FILE: + raise ValueError( + f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" + ) files.append(file) result = cast( diff --git a/api/models/workflow.py b/api/models/workflow.py index 7edfb6aeef..347804b091 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1625,21 +1625,22 @@ class WorkflowDraftVariable(Base): # Rebuild them through the file factory so tenant ownership, signed URLs, # and storage-backed metadata come from canonical records instead of the # serialized JSON blob. - if segment_type == SegmentType.FILE: - if isinstance(value, File): - return build_segment_with_type(segment_type, value) - elif isinstance(value, dict): - file = self._rebuild_file_types(value) - return build_segment_with_type(segment_type, file) - else: - raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") - if segment_type == SegmentType.ARRAY_FILE: - if not isinstance(value, list): - raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") - file_list = self._rebuild_file_types(value) - return build_segment_with_type(segment_type=segment_type, value=file_list) - - return build_segment_with_type(segment_type=segment_type, value=value) + match segment_type: + case SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = self._rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + case SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = self._rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + case _: + return build_segment_with_type(segment_type=segment_type, value=value) @staticmethod def rebuild_file_types(value: Any): @@ -1672,21 +1673,22 @@ class WorkflowDraftVariable(Base): # Extends `variable_factory.build_segment_with_type` functionality by # reconstructing `FileSegment`` or `ArrayFileSegment`` objects from # their serialized dictionary or list representations, respectively. - if segment_type == SegmentType.FILE: - if isinstance(value, File): - return build_segment_with_type(segment_type, value) - elif isinstance(value, dict): - file = cls.rebuild_file_types(value) - return build_segment_with_type(segment_type, file) - else: - raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") - if segment_type == SegmentType.ARRAY_FILE: - if not isinstance(value, list): - raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") - file_list = cls.rebuild_file_types(value) - return build_segment_with_type(segment_type=segment_type, value=file_list) - - return build_segment_with_type(segment_type=segment_type, value=value) + match segment_type: + case SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + case SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + case _: + return build_segment_with_type(segment_type=segment_type, value=value) def get_value(self) -> Segment: """Decode the serialized value into its corresponding `Segment` object. diff --git a/api/schedule/clean_workflow_runlogs_precise.py b/api/schedule/clean_workflow_runlogs_precise.py index ebb8d52924..c5762fcdad 100644 --- a/api/schedule/clean_workflow_runlogs_precise.py +++ b/api/schedule/clean_workflow_runlogs_precise.py @@ -4,6 +4,7 @@ import time from collections.abc import Sequence import click +from sqlalchemy import delete, select from sqlalchemy.orm import Session, sessionmaker import app @@ -113,11 +114,9 @@ def _delete_batch( try: with session.begin_nested(): workflow_run_ids = [run.id for run in workflow_runs] - message_data = ( - session.query(Message.id, Message.conversation_id) - .where(Message.workflow_run_id.in_(workflow_run_ids)) - .all() - ) + message_data = session.execute( + select(Message.id, Message.conversation_id).where(Message.workflow_run_id.in_(workflow_run_ids)) + ).all() message_id_list = [msg.id for msg in message_data] conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id}) if message_id_list: @@ -132,23 +131,19 @@ def _delete_batch( SavedMessage, ] for model in message_related_models: - session.query(model).where(model.message_id.in_(message_id_list)).delete(synchronize_session=False) # type: ignore + session.execute(delete(model).where(model.message_id.in_(message_id_list))) # type: ignore # error: "DeclarativeAttributeIntercept" has no attribute "message_id". But this type is only in lib # and these 6 types all have the message_id field. - session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete( - synchronize_session=False - ) + session.execute(delete(Message).where(Message.workflow_run_id.in_(workflow_run_ids))) if conversation_id_list: - session.query(ConversationVariable).where( - ConversationVariable.conversation_id.in_(conversation_id_list) - ).delete(synchronize_session=False) - - session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete( - synchronize_session=False + session.execute( + delete(ConversationVariable).where(ConversationVariable.conversation_id.in_(conversation_id_list)) ) + session.execute(delete(Conversation).where(Conversation.id.in_(conversation_id_list))) + def _delete_node_executions(active_session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: run_ids = [run.id for run in runs] repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index b4a7fa051f..b0f7efaccd 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -120,7 +120,7 @@ class ClearFreePlanTenantExpiredLogs: apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all() app_ids = [app.id for app in apps] while True: - with Session(db.engine).no_autoflush as session: + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: messages = ( session.query(Message) .where( @@ -152,7 +152,6 @@ class ClearFreePlanTenantExpiredLogs: ).delete(synchronize_session=False) cls._clear_message_related_tables(session, tenant_id, message_ids) - session.commit() click.echo( click.style( @@ -161,7 +160,7 @@ class ClearFreePlanTenantExpiredLogs: ) while True: - with Session(db.engine).no_autoflush as session: + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: conversations = ( session.query(Conversation) .where( @@ -190,7 +189,6 @@ class ClearFreePlanTenantExpiredLogs: session.query(Conversation).where( Conversation.id.in_(conversation_ids), ).delete(synchronize_session=False) - session.commit() click.echo( click.style( @@ -294,7 +292,7 @@ class ClearFreePlanTenantExpiredLogs: break while True: - with Session(db.engine).no_autoflush as session: + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: workflow_app_logs = ( session.query(WorkflowAppLog) .where( @@ -326,7 +324,6 @@ class ClearFreePlanTenantExpiredLogs: session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete( synchronize_session=False ) - session.commit() click.echo( click.style( diff --git a/api/services/end_user_service.py b/api/services/end_user_service.py index 29ada270ec..749d8dbc30 100644 --- a/api/services/end_user_service.py +++ b/api/services/end_user_service.py @@ -2,7 +2,7 @@ import logging from collections.abc import Mapping from sqlalchemy import case, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db @@ -24,7 +24,7 @@ class EndUserService: when an end-user ID is known. """ - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: return session.scalar( select(EndUser) .where( @@ -54,7 +54,7 @@ class EndUserService: if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility # This single query approach is more efficient than separate queries end_user = session.scalar( @@ -82,7 +82,6 @@ class EndUserService: user_id, ) end_user.type = type - session.commit() else: # Create new end user if none exists end_user = EndUser( @@ -94,7 +93,6 @@ class EndUserService: external_user_id=user_id, ) session.add(end_user) - session.commit() return end_user @@ -135,7 +133,7 @@ class EndUserService: if not unique_app_ids: return result - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Fetch existing end users for all target apps in a single query existing_end_users: list[EndUser] = list( session.scalars( @@ -174,7 +172,6 @@ class EndUserService: ) session.add_all(new_end_users) - session.commit() for eu in new_end_users: result[eu.app_id] = eu diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index 174bed488d..adbed87c3c 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -1,4 +1,4 @@ -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models.account import TenantPluginAutoUpgradeStrategy @@ -7,7 +7,7 @@ from models.account import TenantPluginAutoUpgradeStrategy class PluginAutoUpgradeService: @staticmethod def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: return ( session.query(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -23,7 +23,7 @@ class PluginAutoUpgradeService: exclude_plugins: list[str], include_plugins: list[str], ) -> bool: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: exist_strategy = ( session.query(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -46,12 +46,11 @@ class PluginAutoUpgradeService: exist_strategy.exclude_plugins = exclude_plugins exist_strategy.include_plugins = include_plugins - session.commit() return True @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: exist_strategy = ( session.query(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -83,5 +82,4 @@ class PluginAutoUpgradeService: exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE exist_strategy.exclude_plugins = [plugin_id] - session.commit() return True diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 60fa269640..55276d6f99 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -1,4 +1,4 @@ -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models.account import TenantPluginPermission @@ -7,7 +7,7 @@ from models.account import TenantPluginPermission class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() @staticmethod @@ -16,7 +16,7 @@ class PluginPermissionService: install_permission: TenantPluginPermission.InstallPermission, debug_permission: TenantPluginPermission.DebugPermission, ): - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: permission = ( session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() ) @@ -30,5 +30,4 @@ class PluginPermissionService: permission.install_permission = install_permission permission.debug_permission = debug_permission - session.commit() return True diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 04156713f4..e42c020925 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -5,7 +5,6 @@ import logging import uuid from collections.abc import Mapping from datetime import UTC, datetime -from enum import StrEnum from typing import cast from urllib.parse import urlparse from uuid import uuid4 @@ -38,6 +37,7 @@ from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode from models.workflow import Workflow, WorkflowType +from services.app_dsl_service import ImportMode, ImportStatus from services.entities.knowledge_entities.rag_pipeline_entities import ( IconInfo, KnowledgeConfiguration, @@ -54,18 +54,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB CURRENT_DSL_VERSION = "0.1.0" -class ImportMode(StrEnum): - YAML_CONTENT = "yaml-content" - YAML_URL = "yaml-url" - - -class ImportStatus(StrEnum): - COMPLETED = "completed" - COMPLETED_WITH_WARNINGS = "completed-with-warnings" - PENDING = "pending" - FAILED = "failed" - - class RagPipelineImportInfo(BaseModel): id: str status: ImportStatus diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py index 0e0dbab2d1..1e9f0bf149 100644 --- a/api/services/retention/conversation/messages_clean_service.py +++ b/api/services/retention/conversation/messages_clean_service.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, TypedDict, cast import sqlalchemy as sa from sqlalchemy import delete, select, tuple_ from sqlalchemy.engine import CursorResult -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from extensions.ext_database import db @@ -369,7 +369,7 @@ class MessagesCleanService: batch_deleted_messages = 0 # Step 1: Fetch a batch of messages using cursor - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: fetch_messages_start = time.monotonic() msg_stmt = ( select(Message.id, Message.app_id, Message.created_at) @@ -477,7 +477,7 @@ class MessagesCleanService: # Step 4: Batch delete messages and their relations if not self._dry_run: - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: delete_relations_start = time.monotonic() # Delete related records first self._batch_delete_message_relations(session, message_ids_to_delete) @@ -489,9 +489,7 @@ class MessagesCleanService: delete_result = cast(CursorResult, session.execute(delete_stmt)) messages_deleted = delete_result.rowcount delete_messages_ms = int((time.monotonic() - delete_messages_start) * 1000) - commit_start = time.monotonic() - session.commit() - commit_ms = int((time.monotonic() - commit_start) * 1000) + commit_ms = 0 stats["total_deleted"] += messages_deleted batch_deleted_messages = messages_deleted diff --git a/api/services/trigger/app_trigger_service.py b/api/services/trigger/app_trigger_service.py index 6d5a719f63..723d29e947 100644 --- a/api/services/trigger/app_trigger_service.py +++ b/api/services/trigger/app_trigger_service.py @@ -8,7 +8,7 @@ This service centralizes all AppTrigger-related business logic. import logging from sqlalchemy import update -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models.enums import AppTriggerStatus @@ -34,13 +34,12 @@ class AppTriggerService: """ try: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.execute( update(AppTrigger) .where(AppTrigger.tenant_id == tenant_id, AppTrigger.status == AppTriggerStatus.ENABLED) .values(status=AppTriggerStatus.RATE_LIMITED) ) - session.commit() logger.info("Marked all enabled triggers as rate limited for tenant %s", tenant_id) except Exception: logger.exception("Failed to mark all enabled triggers as rate limited for tenant %s", tenant_id) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 008d8bdb8a..ae74f7a8cd 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -6,7 +6,7 @@ from collections.abc import Mapping from typing import Any from sqlalchemy import desc, func -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -146,7 +146,7 @@ class TriggerProviderService: """ try: provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Use distributed lock to prevent race conditions lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}" with redis_client.lock(lock_key, timeout=20): @@ -205,7 +205,6 @@ class TriggerProviderService: subscription.id = subscription_id or str(uuid.uuid4()) session.add(subscription) - session.commit() return { "result": "success", @@ -241,7 +240,7 @@ class TriggerProviderService: :param expires_at: Optional new expiration timestamp :return: Success response with updated subscription info """ - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Use distributed lock to prevent race conditions on the same subscription lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}" with redis_client.lock(lock_key, timeout=20): @@ -302,8 +301,6 @@ class TriggerProviderService: if expires_at is not None: subscription.expires_at = expires_at - session.commit() - # Clear subscription cache delete_cache_for_subscription( tenant_id=tenant_id, @@ -404,7 +401,7 @@ class TriggerProviderService: :param subscription_id: Subscription instance ID :return: New token info """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() if not subscription: @@ -448,7 +445,6 @@ class TriggerProviderService: # Update credentials subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials))) subscription.credential_expires_at = refreshed_credentials.expires_at - session.commit() # Clear cache cache.delete() @@ -478,7 +474,7 @@ class TriggerProviderService: """ now_ts: int = int(now if now is not None else _time.time()) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: subscription: TriggerSubscription | None = ( session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() ) @@ -531,7 +527,6 @@ class TriggerProviderService: # Persist refreshed properties and expires_at subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties))) subscription.expires_at = int(refreshed.expires_at) - session.commit() properties_cache.delete() logger.info( @@ -639,7 +634,7 @@ class TriggerProviderService: tenant_id=tenant_id, provider_id=provider_id ) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Find existing custom client params custom_client = ( session.query(TriggerOAuthTenantClient) @@ -683,8 +678,6 @@ class TriggerProviderService: if enabled is not None: custom_client.enabled = enabled - session.commit() - return {"result": "success"} @classmethod @@ -733,13 +726,12 @@ class TriggerProviderService: :param provider_id: Provider identifier :return: Success response """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.query(TriggerOAuthTenantClient).filter_by( tenant_id=tenant_id, provider=provider_id.provider_name, plugin_id=provider_id.plugin_id, ).delete() - session.commit() return {"result": "success"} diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index d72c041609..5a5d13b96d 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -8,7 +8,7 @@ from flask import Request, Response from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse @@ -215,7 +215,7 @@ class TriggerService: not_found_in_cache.append(node_info) continue - with Session(db.engine) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: try: # lock the concurrent plugin trigger creation redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10) @@ -260,7 +260,6 @@ class TriggerService: cache.model_dump_json(), ex=60 * 60, ) - session.commit() # Update existing records if subscription_id changed for node_info in nodes_in_graph: @@ -290,14 +289,12 @@ class TriggerService: cache.model_dump_json(), ex=60 * 60, ) - session.commit() # delete the nodes not found in the graph for node_id in nodes_id_in_db: if node_id not in nodes_id_in_graph: session.delete(nodes_id_in_db[node_id]) redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}") - session.commit() except Exception: logger.exception("Failed to sync plugin trigger relationships for app %s", app.id) raise diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index f72c69a33e..7b69ccfce7 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -12,7 +12,7 @@ from graphon.file import FileTransferMethod from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from werkzeug.datastructures import FileStorage from werkzeug.exceptions import RequestEntityTooLarge @@ -597,21 +597,38 @@ class WebhookService: Raises: ValueError: If the value cannot be converted to the specified type """ - if param_type == SegmentType.STRING: - return value - elif param_type == SegmentType.NUMBER: - if not cls._can_convert_to_number(value): - raise ValueError(f"Cannot convert '{value}' to number") - numeric_value = float(value) - return int(numeric_value) if numeric_value.is_integer() else numeric_value - elif param_type == SegmentType.BOOLEAN: - lower_value = value.lower() - bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False} - if lower_value not in bool_map: - raise ValueError(f"Cannot convert '{value}' to boolean") - return bool_map[lower_value] - else: - raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") + match param_type: + case SegmentType.STRING: + return value + case SegmentType.NUMBER: + if not cls._can_convert_to_number(value): + raise ValueError(f"Cannot convert '{value}' to number") + numeric_value = float(value) + return int(numeric_value) if numeric_value.is_integer() else numeric_value + case SegmentType.BOOLEAN: + lower_value = value.lower() + bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False} + if lower_value not in bool_map: + raise ValueError(f"Cannot convert '{value}' to boolean") + return bool_map[lower_value] + case ( + SegmentType.OBJECT + | SegmentType.FILE + | SegmentType.ARRAY_ANY + | SegmentType.ARRAY_STRING + | SegmentType.ARRAY_NUMBER + | SegmentType.ARRAY_OBJECT + | SegmentType.ARRAY_FILE + | SegmentType.ARRAY_BOOLEAN + | SegmentType.SECRET + | SegmentType.INTEGER + | SegmentType.FLOAT + | SegmentType.NONE + | SegmentType.GROUP + ): + raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") + case _: + raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") @classmethod def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any: @@ -912,7 +929,7 @@ class WebhookService: logger.warning("Failed to acquire lock for webhook sync, app %s", app.id) raise RuntimeError("Failed to acquire lock for webhook trigger synchronization") - with Session(db.engine) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # fetch the non-cached nodes from DB all_records = session.scalars( select(WorkflowWebhookTrigger).where( @@ -941,14 +958,12 @@ class WebhookService: redis_client.set( f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}", cache.model_dump_json(), ex=60 * 60 ) - session.commit() # delete the nodes not found in the graph for node_id in nodes_id_in_db: if node_id not in nodes_id_in_graph: session.delete(nodes_id_in_db[node_id]) redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}") - session.commit() except Exception: logger.exception("Failed to sync webhook relationships for app %s", app.id) raise diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index ae55c9ee03..c9d4673c0a 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import delete, select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -30,7 +31,9 @@ def add_document_to_index_task(dataset_document_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() + dataset_document = session.scalar( + select(DatasetDocument).where(DatasetDocument.id == dataset_document_id).limit(1) + ) if not dataset_document: logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) return @@ -45,15 +48,14 @@ def add_document_to_index_task(dataset_document_id: str): if not dataset: raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.") - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.status == SegmentStatus.COMPLETED, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() documents = [] multimodal_documents = [] @@ -104,18 +106,15 @@ def add_document_to_index_task(dataset_document_id: str): index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) # delete auto disable log - session.query(DatasetAutoDisableLog).where( - DatasetAutoDisableLog.document_id == dataset_document.id - ).delete() + session.execute( + delete(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id) + ) # update segment to enable - session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( - { - DocumentSegment.enabled: True, - DocumentSegment.disabled_at: None, - DocumentSegment.disabled_by: None, - DocumentSegment.updated_at: naive_utc_now(), - } + session.execute( + update(DocumentSegment) + .where(DocumentSegment.document_id == dataset_document.id) + .values(enabled=True, disabled_at=None, disabled_by=None, updated_at=naive_utc_now()) ) session.commit() diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 66aafc30b9..56c371fcc1 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -1,9 +1,11 @@ import logging import time +from typing import cast import click from celery import shared_task from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -92,14 +94,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form # ============ Step 3: Delete metadata binding (separate short transaction) ============ try: with session_factory.create_session() as session: - deleted_count = int( - session.query(DatasetMetadataBinding) - .where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id.in_(document_ids), - ) - .delete(synchronize_session=False) + result = cast( + CursorResult, + session.execute( + delete(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id.in_(document_ids), + ) + ), ) + deleted_count = result.rowcount session.commit() logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id) except Exception: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index a017e9114b..a657cd553a 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -32,7 +32,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Document has no dataset") @@ -63,7 +63,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if index_node_ids: index_processor = IndexProcessorFactory(doc_form).init_index_processor() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset: index_processor.clean( dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True @@ -94,7 +94,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session, session.begin(): if file_id: - file = session.query(UploadFile).where(UploadFile.id == file_id).first() + file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if file: try: storage.delete(file.key) @@ -124,10 +124,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session, session.begin(): # delete dataset metadata binding - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id == document_id, - ).delete() + session.execute( + delete(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id == document_id, + ) + ) end_at = time.perf_counter() logger.info( diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py index 9664b8ac73..0b392f6096 100644 --- a/api/tasks/delete_conversation_task.py +++ b/api/tasks/delete_conversation_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import delete from core.db.session_factory import session_factory from models import ConversationVariable @@ -29,29 +30,21 @@ def delete_conversation_related_data(conversation_id: str): with session_factory.create_session() as session: try: - session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete( - synchronize_session=False + session.execute(delete(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id)) + + session.execute(delete(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id)) + + session.execute( + delete(ToolConversationVariables).where(ToolConversationVariables.conversation_id == conversation_id) ) - session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete( - synchronize_session=False - ) + session.execute(delete(ToolFile).where(ToolFile.conversation_id == conversation_id)) - session.query(ToolConversationVariables).where( - ToolConversationVariables.conversation_id == conversation_id - ).delete(synchronize_session=False) + session.execute(delete(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id)) - session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False) + session.execute(delete(Message).where(Message.conversation_id == conversation_id)) - session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False) - - session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False - ) + session.execute(delete(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id)) session.commit() diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index a6a2dcebc8..306a23aeda 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import delete +from sqlalchemy import delete, select from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -29,12 +29,12 @@ def delete_segment_from_index_task( start_at = time.perf_counter() with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) return - dataset_document = session.query(Document).where(Document.id == document_id).first() + dataset_document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not dataset_document: return @@ -60,11 +60,9 @@ def delete_segment_from_index_task( ) if dataset.is_multimodal: # delete segment attachment binding - segment_attachment_bindings = ( - session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() - ) + segment_attachment_bindings = session.scalars( + select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + ).all() if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) @@ -77,7 +75,7 @@ def delete_segment_from_index_task( session.execute(segment_attachment_bind_delete_stmt) # delete upload file - session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + session.execute(delete(UploadFile).where(UploadFile.id.in_(attachment_ids))) session.commit() end_at = time.perf_counter() diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 3cc267e821..86e96ea3f0 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -27,12 +27,12 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) return - dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1)) if not dataset_document: logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) @@ -58,11 +58,9 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen index_node_ids = [segment.index_node_id for segment in segments] if dataset.is_multimodal: segment_ids = [segment.id for segment in segments] - segment_attachment_bindings = ( - session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() - ) + segment_attachment_bindings = session.scalars( + select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + ).all() if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_node_ids.extend(attachment_ids) @@ -87,16 +85,14 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) except Exception: # update segment error msg - session.query(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ).update( - { - "disabled_at": None, - "disabled_by": None, - "enabled": True, - } + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .values(disabled_at=None, disabled_by=None, enabled=True) ) session.commit() finally: diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 6f490ab7ea..e794195c92 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -47,7 +47,7 @@ def regenerate_summary_index_task( try: with session_factory.create_session() as session: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) return @@ -84,8 +84,8 @@ def regenerate_summary_index_task( # For embedding_model change: directly query all segments with existing summaries # Don't require document indexing_status == "completed" # Include summaries with status "completed" or "error" (if they have content) - segments_with_summaries = ( - session.query(DocumentSegment, DocumentSegmentSummary) + segments_with_summaries = session.execute( + select(DocumentSegment, DocumentSegmentSummary) .join( DocumentSegmentSummary, DocumentSegment.id == DocumentSegmentSummary.chunk_id, @@ -110,8 +110,7 @@ def regenerate_summary_index_task( DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) - .all() - ) + ).all() if not segments_with_summaries: logger.info( @@ -215,8 +214,8 @@ def regenerate_summary_index_task( try: # Get all segments with existing summaries - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .join( DocumentSegmentSummary, DocumentSegment.id == DocumentSegmentSummary.chunk_id, @@ -229,8 +228,7 @@ def regenerate_summary_index_task( DocumentSegmentSummary.dataset_id == dataset_id, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if not segments: continue @@ -245,13 +243,13 @@ def regenerate_summary_index_task( summary_record = None try: # Get existing summary record - summary_record = ( - session.query(DocumentSegmentSummary) - .filter_by( - chunk_id=segment.id, - dataset_id=dataset_id, + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset_id, ) - .first() + .limit(1) ) if not summary_record: diff --git a/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py new file mode 100644 index 0000000000..ccc4188dbf --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.model import AccountTrialAppRecord, TrialApp +from services import recommended_app_service as service_module +from services.recommended_app_service import RecommendedAppService + +# ── Helpers ──────────────────────────────────────────────────────────── + + +def _apps_response( + recommended_apps: list[dict] | None = None, + categories: list[str] | None = None, +) -> dict: + if recommended_apps is None: + recommended_apps = [ + {"id": "app-1", "name": "Test App 1", "description": "d1", "category": "productivity"}, + {"id": "app-2", "name": "Test App 2", "description": "d2", "category": "communication"}, + ] + if categories is None: + categories = ["productivity", "communication", "utilities"] + return {"recommended_apps": recommended_apps, "categories": categories} + + +def _app_detail( + app_id: str = "app-123", + name: str = "Test App", + description: str = "Test description", + **kwargs: Any, +) -> dict: + detail: dict[str, Any] = { + "id": app_id, + "name": name, + "description": description, + "category": kwargs.get("category", "productivity"), + "icon": kwargs.get("icon", "🚀"), + "model_config": kwargs.get("model_config", {}), + } + detail.update(kwargs) + return detail + + +def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any] | None: + return cast("dict[str, Any] | None", result) + + +def _mock_factory_for_apps( + monkeypatch: pytest.MonkeyPatch, + *, + mode: str, + result: dict[str, Any], + fallback_result: dict[str, Any] | None = None, +) -> tuple[MagicMock, MagicMock]: + retrieval_instance = MagicMock() + retrieval_instance.get_recommended_apps_and_categories.return_value = result + retrieval_factory = MagicMock(return_value=retrieval_instance) + monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=retrieval_factory), + ) + builtin_instance = MagicMock() + if fallback_result is not None: + builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_buildin_recommend_app_retrieval", + MagicMock(return_value=builtin_instance), + ) + return retrieval_instance, builtin_instance + + +# ── Pure logic tests: get_recommended_apps_and_categories ────────────── + + +class TestRecommendedAppServiceGetApps: + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_success_with_apps(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + expected = _apps_response() + + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = expected + mock_factory = MagicMock(return_value=mock_instance) + mock_factory_class.get_recommend_app_factory.return_value = mock_factory + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == expected + assert len(result["recommended_apps"]) == 2 + assert len(result["categories"]) == 3 + mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote") + mock_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + empty_response = {"recommended_apps": [], "categories": []} + builtin_response = _apps_response( + recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}] + ) + + mock_remote_instance = MagicMock() + mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_remote_instance) + + mock_builtin_instance = MagicMock() + mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response + mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance + + result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN") + + assert result == builtin_response + assert result["recommended_apps"][0]["id"] == "builtin-1" + mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db" + none_response = {"recommended_apps": None, "categories": ["test"]} + builtin_response = _apps_response() + + mock_db_instance = MagicMock() + mock_db_instance.get_recommended_apps_and_categories.return_value = none_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_db_instance) + + mock_builtin_instance = MagicMock() + mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response + mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == builtin_response + mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once() + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_different_languages(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin" + + for language in ["en-US", "zh-CN", "ja-JP", "fr-FR"]: + lang_response = _apps_response( + recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}] + ) + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = lang_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = RecommendedAppService.get_recommended_apps_and_categories(language) + + assert result["recommended_apps"][0]["id"] == f"app-{language}" + mock_instance.get_recommended_apps_and_categories.assert_called_with(language) + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_uses_correct_factory_mode(self, mock_config, mock_factory_class): + for mode in ["remote", "builtin", "db"]: + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode + response = _apps_response() + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + RecommendedAppService.get_recommended_apps_and_categories("en-US") + + mock_factory_class.get_recommend_app_factory.assert_called_with(mode) + + +# ── Pure logic tests: get_recommend_app_detail ───────────────────────── + + +class TestRecommendedAppServiceGetDetail: + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_success(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + expected = _app_detail(app_id="app-123", name="Productivity App", description="A great app") + + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = expected + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-123")) + + assert result == expected + assert result["id"] == "app-123" + mock_instance.get_recommend_app_detail.assert_called_once_with("app-123") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_different_modes(self, mock_config, mock_factory_class): + for mode in ["remote", "builtin", "db"]: + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode + detail = _app_detail(app_id="test-app", name=f"App from {mode}") + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = detail + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("test-app")) + + assert result["name"] == f"App from {mode}" + mock_factory_class.get_recommend_app_factory.assert_called_with(mode) + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_returns_none_when_not_found(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = None + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("nonexistent")) + + assert result is None + mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_returns_empty_dict(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin" + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = {} + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-empty")) + + assert result == {} + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_complex_model_config(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + complex_config = { + "provider": "openai", + "model": "gpt-4", + "parameters": {"temperature": 0.7, "max_tokens": 2000, "top_p": 1.0}, + } + expected = _app_detail( + app_id="complex-app", + name="Complex App", + model_config=complex_config, + workflows=["workflow-1", "workflow-2"], + tools=["tool-1", "tool-2", "tool-3"], + ) + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = expected + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("complex-app")) + + assert result["model_config"] == complex_config + assert len(result["workflows"]) == 2 + assert len(result["tools"]) == 3 + + +# ── Integration tests: trial app features (real DB) ──────────────────── + + +class TestRecommendedAppServiceTrialFeatures: + def test_get_apps_should_not_query_trial_table_when_disabled( + self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch + ): + expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]} + retrieval_instance, builtin_instance = _mock_factory_for_apps(monkeypatch, mode="remote", result=expected) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=False)), + ) + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == expected + retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") + builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called() + + def test_get_apps_should_enrich_can_trial_when_enabled( + self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch + ): + app_id_1 = str(uuid.uuid4()) + app_id_2 = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + # app_id_1 has a TrialApp record; app_id_2 does not + db_session_with_containers.add(TrialApp(app_id=app_id_1, tenant_id=tenant_id)) + db_session_with_containers.commit() + + remote_result = {"recommended_apps": [], "categories": []} + fallback_result = { + "recommended_apps": [{"app_id": app_id_1}, {"app_id": app_id_2}], + "categories": ["all"], + } + _, builtin_instance = _mock_factory_for_apps( + monkeypatch, mode="remote", result=remote_result, fallback_result=fallback_result + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), + ) + + result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP") + + builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") + assert result["recommended_apps"][0]["can_trial"] is True + assert result["recommended_apps"][1]["can_trial"] is False + + @pytest.mark.parametrize("has_trial_app", [True, False]) + def test_get_detail_should_set_can_trial_when_enabled( + self, + db_session_with_containers: Session, + monkeypatch: pytest.MonkeyPatch, + has_trial_app: bool, + ): + app_id = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + if has_trial_app: + db_session_with_containers.add(TrialApp(app_id=app_id, tenant_id=tenant_id)) + db_session_with_containers.commit() + + detail = {"id": app_id, "name": "Test App"} + retrieval_instance = MagicMock() + retrieval_instance.get_recommend_app_detail.return_value = detail + retrieval_factory = MagicMock(return_value=retrieval_instance) + monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=retrieval_factory), + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), + ) + + result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail(app_id)) + + assert result["id"] == app_id + assert result["can_trial"] is has_trial_app + + def test_add_trial_app_record_increments_count_for_existing(self, db_session_with_containers: Session): + app_id = str(uuid.uuid4()) + account_id = str(uuid.uuid4()) + + db_session_with_containers.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=3)) + db_session_with_containers.commit() + + RecommendedAppService.add_trial_app_record(app_id, account_id) + + db_session_with_containers.expire_all() + record = db_session_with_containers.scalar( + select(AccountTrialAppRecord) + .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) + .limit(1) + ) + assert record is not None + assert record.count == 4 + + def test_add_trial_app_record_creates_new_record(self, db_session_with_containers: Session): + app_id = str(uuid.uuid4()) + account_id = str(uuid.uuid4()) + + RecommendedAppService.add_trial_app_record(app_id, account_id) + + db_session_with_containers.expire_all() + record = db_session_with_containers.scalar( + select(AccountTrialAppRecord) + .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) + .limit(1) + ) + assert record is not None + assert record.app_id == app_id + assert record.account_id == account_id + assert record.count == 1 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 061719d15a..1fb0dc6cf1 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -134,6 +134,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -150,7 +151,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -177,7 +180,6 @@ class TestAdvancedChatAppRunnerConversationVariables: # Note: Since we're mocking ConversationVariable.from_variable, # we can't directly check the id, but we can verify add_all was called assert mock_session.add_all.called, "Session add_all should have been called" - assert mock_session.commit.called, "Session commit should have been called" def test_no_variables_creates_all(self): """Test that all conversation variables are created when none exist in DB.""" @@ -278,6 +280,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -295,7 +298,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -326,7 +331,6 @@ class TestAdvancedChatAppRunnerConversationVariables: # Verify that all variables were created assert len(added_items) == 2, "Should have added both variables" assert mock_session.add_all.called, "Session add_all should have been called" - assert mock_session.commit.called, "Session commit should have been called" def test_all_variables_exist_no_changes(self): """Test that no changes are made when all variables already exist in DB.""" @@ -429,6 +433,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -445,7 +450,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -465,4 +472,3 @@ class TestAdvancedChatAppRunnerConversationVariables: # Verify that no variables were added assert not mock_session.add_all.called, "Session add_all should not have been called" - assert mock_session.commit.called, "Session commit should still be called" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 079df0b4e6..5d8faee897 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -93,6 +93,16 @@ def _patch_common_run_deps(runner: AdvancedChatAppRunner): scalar=lambda *a, **k: MagicMock(), ), ), + sessionmaker=MagicMock( + return_value=MagicMock( + begin=MagicMock( + return_value=MagicMock( + __enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))), + __exit__=lambda *a, **k: False, + ), + ), + ), + ), select=MagicMock(), db=MagicMock(engine=MagicMock()), RedisChannel=MagicMock(), diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index dabd2594b4..d91bb85aee 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -2,6 +2,7 @@ from __future__ import annotations from contextlib import contextmanager from types import SimpleNamespace +from unittest.mock import MagicMock import pytest from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus @@ -610,33 +611,33 @@ class TestWorkflowGenerateTaskPipeline: def test_database_session_rolls_back_on_error(self, monkeypatch): pipeline = _make_pipeline() - calls = {"commit": 0, "rollback": 0} - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs + calls = {"enter": 0, "exit_exc": None} + class _BeginContext: def __enter__(self): - return self + calls["enter"] += 1 + return MagicMock() def __exit__(self, exc_type, exc, tb): + calls["exit_exc"] = exc_type return False - def commit(self): - calls["commit"] += 1 + class _Sessionmaker: + def __init__(self, *args, **kwargs): + pass - def rollback(self): - calls["rollback"] += 1 + def begin(self): + return _BeginContext() - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.sessionmaker", _Sessionmaker) monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) with pytest.raises(RuntimeError, match="db error"): with pipeline._database_session(): raise RuntimeError("db error") - assert calls["commit"] == 0 - assert calls["rollback"] == 1 + assert calls["enter"] == 1 + assert calls["exit_exc"] is RuntimeError def test_node_retry_and_started_handlers_cover_none_and_value(self): pipeline = _make_pipeline() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py index 545565cdf4..d4fa4b3e8e 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py @@ -71,7 +71,7 @@ def test_vector_methods_delegate_to_underlying_implementation(): assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value vector.delete() - runner._create_collection_if_not_exists.assert_called_once_with(2) + runner.create_collection_if_not_exists.assert_called_once_with(2) runner.add_texts.assert_any_call(texts, [[0.1, 0.2]]) runner.delete_by_ids.assert_called_once_with(["d1"]) runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py index 45777774d0..4f8653a926 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py @@ -249,7 +249,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch): vector._client = MagicMock() vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404) - vector._create_collection_if_not_exists(embedding_dimension=1024) + vector.create_collection_if_not_exists(embedding_dimension=1024) vector._client.create_collection.assert_called_once() openapi_module.redis_client.set.assert_called_once() @@ -268,7 +268,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch): vector.config = _config() vector._client = MagicMock() - vector._create_collection_if_not_exists(embedding_dimension=1024) + vector.create_collection_if_not_exists(embedding_dimension=1024) vector._client.describe_collection.assert_not_called() vector._client.create_collection.assert_not_called() @@ -290,7 +290,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch): vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500) with pytest.raises(ValueError, match="failed to create collection collection_1"): - vector._create_collection_if_not_exists(embedding_dimension=512) + vector.create_collection_if_not_exists(embedding_dimension=512) def test_openapi_add_delete_and_search_methods(monkeypatch): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py index 8f1206696b..f798ef8bd1 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py @@ -374,7 +374,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp vector._get_cursor = _cursor_context - vector._create_collection_if_not_exists(embedding_dimension=3) + vector.create_collection_if_not_exists(embedding_dimension=3) assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list) assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list) @@ -404,7 +404,7 @@ def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypat vector._get_cursor = _cursor_context with pytest.raises(RuntimeError, match="permission denied"): - vector._create_collection_if_not_exists(embedding_dimension=3) + vector.create_collection_if_not_exists(embedding_dimension=3) def test_delete_methods_raise_when_error_is_not_missing_table(): diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 40d138df90..b98fec3854 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -4909,15 +4909,17 @@ class TestInternalHooksCoverage: session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False + sessionmaker_ctx = MagicMock() + sessionmaker_ctx.begin.return_value = session_ctx + with ( patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), - patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx), + patch("core.rag.retrieval.dataset_retrieval.sessionmaker", return_value=sessionmaker_ctx), patch.object(retrieval, "_send_trace_task") as mock_trace, ): retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) query.update.assert_called_once() - session.commit.assert_called_once() mock_trace.assert_called_once() def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: diff --git a/api/tests/unit_tests/extensions/test_redis.py b/api/tests/unit_tests/extensions/test_redis.py index 933fa32894..5e9be4ab9b 100644 --- a/api/tests/unit_tests/extensions/test_redis.py +++ b/api/tests/unit_tests/extensions/test_redis.py @@ -1,53 +1,125 @@ +from unittest.mock import patch + from redis import RedisError +from redis.retry import Retry -from extensions.ext_redis import redis_fallback +from extensions.ext_redis import ( + _get_base_redis_params, + _get_cluster_connection_health_params, + _get_connection_health_params, + redis_fallback, +) -def test_redis_fallback_success(): - @redis_fallback(default_return=None) - def test_func(): - return "success" +class TestGetConnectionHealthParams: + @patch("extensions.ext_redis.dify_config") + def test_includes_all_health_params(self, mock_config): + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() == "success" + params = _get_connection_health_params() + + assert "retry" in params + assert "socket_timeout" in params + assert "socket_connect_timeout" in params + assert "health_check_interval" in params + assert isinstance(params["retry"], Retry) + assert params["retry"]._retries == 3 + assert params["socket_timeout"] == 5.0 + assert params["socket_connect_timeout"] == 5.0 + assert params["health_check_interval"] == 30 -def test_redis_fallback_error(): - @redis_fallback(default_return="fallback") - def test_func(): - raise RedisError("Redis error") +class TestGetClusterConnectionHealthParams: + @patch("extensions.ext_redis.dify_config") + def test_excludes_health_check_interval(self, mock_config): + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() == "fallback" + params = _get_cluster_connection_health_params() + + assert "retry" in params + assert "socket_timeout" in params + assert "socket_connect_timeout" in params + assert "health_check_interval" not in params -def test_redis_fallback_none_default(): - @redis_fallback() - def test_func(): - raise RedisError("Redis error") +class TestGetBaseRedisParams: + @patch("extensions.ext_redis.dify_config") + def test_includes_retry_and_health_params(self, mock_config): + mock_config.REDIS_USERNAME = None + mock_config.REDIS_PASSWORD = None + mock_config.REDIS_DB = 0 + mock_config.REDIS_SERIALIZATION_PROTOCOL = 3 + mock_config.REDIS_ENABLE_CLIENT_SIDE_CACHE = False + mock_config.REDIS_RETRY_RETRIES = 3 + mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0 + mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0 + mock_config.REDIS_SOCKET_TIMEOUT = 5.0 + mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0 + mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30 - assert test_func() is None + params = _get_base_redis_params() + + assert "retry" in params + assert isinstance(params["retry"], Retry) + assert params["socket_timeout"] == 5.0 + assert params["socket_connect_timeout"] == 5.0 + assert params["health_check_interval"] == 30 + # Existing params still present + assert params["db"] == 0 + assert params["encoding"] == "utf-8" -def test_redis_fallback_with_args(): - @redis_fallback(default_return=0) - def test_func(x, y): - raise RedisError("Redis error") +class TestRedisFallback: + def test_redis_fallback_success(self): + @redis_fallback(default_return=None) + def test_func(): + return "success" - assert test_func(1, 2) == 0 + assert test_func() == "success" + def test_redis_fallback_error(self): + @redis_fallback(default_return="fallback") + def test_func(): + raise RedisError("Redis error") -def test_redis_fallback_with_kwargs(): - @redis_fallback(default_return={}) - def test_func(x=None, y=None): - raise RedisError("Redis error") + assert test_func() == "fallback" - assert test_func(x=1, y=2) == {} + def test_redis_fallback_none_default(self): + @redis_fallback() + def test_func(): + raise RedisError("Redis error") + assert test_func() is None -def test_redis_fallback_preserves_function_metadata(): - @redis_fallback(default_return=None) - def test_func(): - """Test function docstring""" - pass + def test_redis_fallback_with_args(self): + @redis_fallback(default_return=0) + def test_func(x, y): + raise RedisError("Redis error") - assert test_func.__name__ == "test_func" - assert test_func.__doc__ == "Test function docstring" + assert test_func(1, 2) == 0 + + def test_redis_fallback_with_kwargs(self): + @redis_fallback(default_return={}) + def test_func(x=None, y=None): + raise RedisError("Redis error") + + assert test_func(x=1, y=2) == {} + + def test_redis_fallback_preserves_function_metadata(self): + @redis_fallback(default_return=None) + def test_func(): + """Test function docstring""" + pass + + assert test_func.__name__ == "test_func" + assert test_func.__doc__ == "Test function docstring" diff --git a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py index edb50d09a6..45156958b6 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py @@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_auto_upgrade_service" def _patched_session(): - """Patch Session(db.engine) to return a mock session as context manager.""" + """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" session = MagicMock() - session_cls = MagicMock() - session_cls.return_value.__enter__ = MagicMock(return_value=session) - session_cls.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.Session", session_cls) + mock_sessionmaker = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) db_patcher = patch(f"{MODULE}.db") return patcher, db_patcher, session @@ -61,7 +61,6 @@ class TestChangeStrategy: assert result is True session.add.assert_called_once() - session.commit.assert_called_once() def test_updates_existing_strategy(self): p1, p2, session = _patched_session() @@ -86,7 +85,6 @@ class TestChangeStrategy: assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL assert existing.exclude_plugins == ["p1"] assert existing.include_plugins == ["p2"] - session.commit.assert_called_once() class TestExcludePlugin: @@ -127,7 +125,6 @@ class TestExcludePlugin: assert result is True assert existing.exclude_plugins == ["p-existing", "p-new"] - session.commit.assert_called_once() def test_removes_from_include_list_in_partial_mode(self): p1, p2, session = _patched_session() diff --git a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py index 69091110db..40f4c6a8d2 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py @@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_permission_service" def _patched_session(): - """Patch Session(db.engine) to return a mock session as context manager.""" + """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" session = MagicMock() - session_cls = MagicMock() - session_cls.return_value.__enter__ = MagicMock(return_value=session) - session_cls.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.Session", session_cls) + mock_sessionmaker = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) db_patcher = patch(f"{MODULE}.db") return patcher, db_patcher, session @@ -55,7 +55,6 @@ class TestChangePermission: ) session.add.assert_called_once() - session.commit.assert_called_once() def test_updates_existing_permission(self): p1, p2, session = _patched_session() @@ -71,5 +70,4 @@ class TestChangePermission: assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS - session.commit.assert_called_once() session.add.assert_not_called() diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index f393a4b10b..3e989c55a3 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -275,48 +275,46 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) - msg_session_1.query.side_effect = lambda model: ( make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() ) - msg_session_1.commit.return_value = None - msg_session_2 = MagicMock() msg_session_2.query.side_effect = lambda model: ( make_query_with_batches([[]]) if model == service_module.Message else MagicMock() ) - msg_session_2.commit.return_value = None conv_session_1 = MagicMock() conv_session_1.query.side_effect = lambda model: ( make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() ) - conv_session_1.commit.return_value = None conv_session_2 = MagicMock() conv_session_2.query.side_effect = lambda model: ( make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() ) - conv_session_2.commit.return_value = None wal_session_1 = MagicMock() wal_session_1.query.side_effect = lambda model: ( make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() ) - wal_session_1.commit.return_value = None wal_session_2 = MagicMock() wal_session_2.query.side_effect = lambda model: ( make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() ) - wal_session_2.commit.return_value = None session_wrappers = [ - _session_wrapper_for_no_autoflush(msg_session_1), - _session_wrapper_for_no_autoflush(msg_session_2), - _session_wrapper_for_no_autoflush(conv_session_1), - _session_wrapper_for_no_autoflush(conv_session_2), - _session_wrapper_for_no_autoflush(wal_session_1), - _session_wrapper_for_no_autoflush(wal_session_2), + _sessionmaker_wrapper_for_begin(msg_session_1), + _sessionmaker_wrapper_for_begin(msg_session_2), + _sessionmaker_wrapper_for_begin(conv_session_1), + _sessionmaker_wrapper_for_begin(conv_session_2), + _sessionmaker_wrapper_for_begin(wal_session_1), + _sessionmaker_wrapper_for_begin(wal_session_2), ] - monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + def fake_sessionmaker(*args, **kwargs): + if kwargs.get("autoflush") is False: + return session_wrappers.pop(0) + return object() + + monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker) def fake_select(*_args, **_kwargs): stmt = MagicMock() @@ -333,8 +331,6 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) - run_repo = MagicMock() run_repo.get_expired_runs_batch.side_effect = [[SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"})], []] run_repo.delete_runs_by_ids.return_value = 1 - - monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) monkeypatch.setattr( service_module.DifyAPIRepositoryFactory, "create_api_workflow_node_execution_repository", @@ -574,13 +570,18 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte q_empty.limit.return_value = q_empty q_empty.all.return_value = [] empty_session.query.return_value = q_empty - empty_session.commit.return_value = None session_wrappers = [ - _session_wrapper_for_no_autoflush(empty_session), - _session_wrapper_for_no_autoflush(empty_session), - _session_wrapper_for_no_autoflush(empty_session), + _sessionmaker_wrapper_for_begin(empty_session), + _sessionmaker_wrapper_for_begin(empty_session), + _sessionmaker_wrapper_for_begin(empty_session), ] - monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + + def fake_sessionmaker(*args, **kwargs): + if kwargs.get("autoflush") is False: + return session_wrappers.pop(0) + return object() + + monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker) def fake_select(*_args, **_kwargs): stmt = MagicMock() @@ -606,8 +607,6 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte [], ] run_repo.delete_runs_by_ids.return_value = 2 - - monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) monkeypatch.setattr( service_module.DifyAPIRepositoryFactory, "create_api_workflow_node_execution_repository", diff --git a/api/tests/unit_tests/services/test_recommended_app_service.py b/api/tests/unit_tests/services/test_recommended_app_service.py deleted file mode 100644 index 12bc84db87..0000000000 --- a/api/tests/unit_tests/services/test_recommended_app_service.py +++ /dev/null @@ -1,628 +0,0 @@ -""" -Comprehensive unit tests for RecommendedAppService. - -This test suite provides complete coverage of recommended app operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -## Test Coverage - -### 1. Get Recommended Apps and Categories (TestRecommendedAppServiceGetApps) -Tests fetching recommended apps with categories: -- Successful retrieval with recommended apps -- Fallback to builtin when no recommended apps -- Different language support -- Factory mode selection (remote, builtin, db) -- Empty result handling - -### 2. Get Recommend App Detail (TestRecommendedAppServiceGetDetail) -Tests fetching individual app details: -- Successful app detail retrieval -- Different factory modes -- App not found scenarios -- Language-specific details - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (dify_config, RecommendAppRetrievalFactory) - are mocked for fast, isolated unit tests -- **Factory Pattern**: Tests verify correct factory selection based on mode -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and factory method calls - -## Key Concepts - -**Factory Modes:** -- remote: Fetch from remote API -- builtin: Use built-in templates -- db: Fetch from database - -**Fallback Logic:** -- If remote/db returns no apps, fallback to builtin en-US templates -- Ensures users always see some recommended apps -""" - -from unittest.mock import MagicMock, patch - -import pytest - -from services.recommended_app_service import RecommendedAppService - - -class RecommendedAppServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - recommended app operations. - """ - - @staticmethod - def create_recommended_apps_response( - recommended_apps: list[dict] | None = None, - categories: list[str] | None = None, - ) -> dict: - """ - Create a mock response for recommended apps. - - Args: - recommended_apps: List of recommended app dictionaries - categories: List of category names - - Returns: - Dictionary with recommended_apps and categories - """ - if recommended_apps is None: - recommended_apps = [ - { - "id": "app-1", - "name": "Test App 1", - "description": "Test description 1", - "category": "productivity", - }, - { - "id": "app-2", - "name": "Test App 2", - "description": "Test description 2", - "category": "communication", - }, - ] - if categories is None: - categories = ["productivity", "communication", "utilities"] - - return { - "recommended_apps": recommended_apps, - "categories": categories, - } - - @staticmethod - def create_app_detail_response( - app_id: str = "app-123", - name: str = "Test App", - description: str = "Test description", - **kwargs, - ) -> dict: - """ - Create a mock response for app detail. - - Args: - app_id: App identifier - name: App name - description: App description - **kwargs: Additional fields - - Returns: - Dictionary with app details - """ - detail = { - "id": app_id, - "name": name, - "description": description, - "category": kwargs.get("category", "productivity"), - "icon": kwargs.get("icon", "🚀"), - "model_config": kwargs.get("model_config", {}), - } - detail.update(kwargs) - return detail - - -@pytest.fixture -def factory(): - """Provide the test data factory to all tests.""" - return RecommendedAppServiceTestDataFactory - - -class TestRecommendedAppServiceGetApps: - """Test get_recommended_apps_and_categories operations.""" - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory): - """Test successful retrieval of recommended apps when apps are returned.""" - # Arrange - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" - - expected_response = factory.create_recommended_apps_response() - - # Mock factory and retrieval instance - mock_retrieval_instance = MagicMock() - mock_retrieval_instance.get_recommended_apps_and_categories.return_value = expected_response - - mock_factory = MagicMock() - mock_factory.return_value = mock_retrieval_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_factory - - # Act - result = RecommendedAppService.get_recommended_apps_and_categories("en-US") - - # Assert - assert result == expected_response - assert len(result["recommended_apps"]) == 2 - assert len(result["categories"]) == 3 - mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote") - mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory): - """Test fallback to builtin when no recommended apps are returned.""" - # Arrange - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" - - # Remote returns empty recommended_apps - empty_response = {"recommended_apps": [], "categories": []} - - # Builtin fallback response - builtin_response = factory.create_recommended_apps_response( - recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}] - ) - - # Mock remote retrieval instance (returns empty) - mock_remote_instance = MagicMock() - mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response - - mock_remote_factory = MagicMock() - mock_remote_factory.return_value = mock_remote_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_remote_factory - - # Mock builtin retrieval instance - mock_builtin_instance = MagicMock() - mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response - mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance - - # Act - result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN") - - # Assert - assert result == builtin_response - assert len(result["recommended_apps"]) == 1 - assert result["recommended_apps"][0]["id"] == "builtin-1" - # Verify fallback was called with en-US (hardcoded) - mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory): - """Test fallback when recommended_apps key is None.""" - # Arrange - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db" - - # Response with None recommended_apps - none_response = {"recommended_apps": None, "categories": ["test"]} - - # Builtin fallback response - builtin_response = factory.create_recommended_apps_response() - - # Mock db retrieval instance (returns None) - mock_db_instance = MagicMock() - mock_db_instance.get_recommended_apps_and_categories.return_value = none_response - - mock_db_factory = MagicMock() - mock_db_factory.return_value = mock_db_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_db_factory - - # Mock builtin retrieval instance - mock_builtin_instance = MagicMock() - mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response - mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance - - # Act - result = RecommendedAppService.get_recommended_apps_and_categories("en-US") - - # Assert - assert result == builtin_response - mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once() - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory): - """Test retrieval with different language codes.""" - # Arrange - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin" - - languages = ["en-US", "zh-CN", "ja-JP", "fr-FR"] - - for language in languages: - # Create language-specific response - lang_response = factory.create_recommended_apps_response( - recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}] - ) - - # Mock retrieval instance - mock_instance = MagicMock() - mock_instance.get_recommended_apps_and_categories.return_value = lang_response - - mock_factory = MagicMock() - mock_factory.return_value = mock_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_factory - - # Act - result = RecommendedAppService.get_recommended_apps_and_categories(language) - - # Assert - assert result["recommended_apps"][0]["id"] == f"app-{language}" - mock_instance.get_recommended_apps_and_categories.assert_called_with(language) - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory): - """Test that correct factory is selected based on mode.""" - # Arrange - modes = ["remote", "builtin", "db"] - - for mode in modes: - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode - - response = factory.create_recommended_apps_response() - - # Mock retrieval instance - mock_instance = MagicMock() - mock_instance.get_recommended_apps_and_categories.return_value = response - - mock_factory = MagicMock() - mock_factory.return_value = mock_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_factory - - # Act - RecommendedAppService.get_recommended_apps_and_categories("en-US") - - # Assert - mock_factory_class.get_recommend_app_factory.assert_called_with(mode) - - -class TestRecommendedAppServiceGetDetail: - """Test get_recommend_app_detail operations.""" - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory): - """Test successful retrieval of app detail.""" - # Arrange - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" - app_id = "app-123" - - expected_detail = factory.create_app_detail_response( - app_id=app_id, - name="Productivity App", - description="A great productivity app", - category="productivity", - ) - - # Mock retrieval instance - mock_instance = MagicMock() - mock_instance.get_recommend_app_detail.return_value = expected_detail - - mock_factory = MagicMock() - mock_factory.return_value = mock_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_factory - - # Act - result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) - - # Assert - assert result == expected_detail - assert result["id"] == app_id - assert result["name"] == "Productivity App" - mock_instance.get_recommend_app_detail.assert_called_once_with(app_id) - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory): - """Test app detail retrieval with different factory modes.""" - # Arrange - modes = ["remote", "builtin", "db"] - app_id = "test-app" - - for mode in modes: - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode - - detail = factory.create_app_detail_response(app_id=app_id, name=f"App from {mode}") - - # Mock retrieval instance - mock_instance = MagicMock() - mock_instance.get_recommend_app_detail.return_value = detail - - mock_factory = MagicMock() - mock_factory.return_value = mock_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_factory - - # Act - result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) - - # Assert - assert result["name"] == f"App from {mode}" - mock_factory_class.get_recommend_app_factory.assert_called_with(mode) - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory): - """Test that None is returned when app is not found.""" - # Arrange - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" - app_id = "nonexistent-app" - - # Mock retrieval instance returning None - mock_instance = MagicMock() - mock_instance.get_recommend_app_detail.return_value = None - - mock_factory = MagicMock() - mock_factory.return_value = mock_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_factory - - # Act - result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) - - # Assert - assert result is None - mock_instance.get_recommend_app_detail.assert_called_once_with(app_id) - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory): - """Test handling of empty dict response.""" - # Arrange - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin" - app_id = "app-empty" - - # Mock retrieval instance returning empty dict - mock_instance = MagicMock() - mock_instance.get_recommend_app_detail.return_value = {} - - mock_factory = MagicMock() - mock_factory.return_value = mock_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_factory - - # Act - result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) - - # Assert - assert result == {} - - @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) - @patch("services.recommended_app_service.dify_config", autospec=True) - def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory): - """Test app detail with complex model configuration.""" - # Arrange - mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" - app_id = "complex-app" - - complex_model_config = { - "provider": "openai", - "model": "gpt-4", - "parameters": { - "temperature": 0.7, - "max_tokens": 2000, - "top_p": 1.0, - }, - } - - expected_detail = factory.create_app_detail_response( - app_id=app_id, - name="Complex App", - model_config=complex_model_config, - workflows=["workflow-1", "workflow-2"], - tools=["tool-1", "tool-2", "tool-3"], - ) - - # Mock retrieval instance - mock_instance = MagicMock() - mock_instance.get_recommend_app_detail.return_value = expected_detail - - mock_factory = MagicMock() - mock_factory.return_value = mock_instance - mock_factory_class.get_recommend_app_factory.return_value = mock_factory - - # Act - result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) - - # Assert - assert result["model_config"] == complex_model_config - assert len(result["workflows"]) == 2 - assert len(result["tools"]) == 3 - - -# === Merged from test_recommended_app_service_additional.py === - - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest - -from services import recommended_app_service as service_module -from services.recommended_app_service import RecommendedAppService - - -def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any]: - return cast(dict[str, Any], result) - - -@pytest.fixture -def mocked_db_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock: - # Arrange - session = MagicMock() - monkeypatch.setattr(service_module, "db", SimpleNamespace(session=session)) - - # Assert - return session - - -def _mock_factory_for_apps( - monkeypatch: pytest.MonkeyPatch, - *, - mode: str, - result: dict[str, Any], - fallback_result: dict[str, Any] | None = None, -) -> tuple[MagicMock, MagicMock]: - retrieval_instance = MagicMock() - retrieval_instance.get_recommended_apps_and_categories.return_value = result - retrieval_factory = MagicMock(return_value=retrieval_instance) - monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False) - monkeypatch.setattr( - service_module.RecommendAppRetrievalFactory, - "get_recommend_app_factory", - MagicMock(return_value=retrieval_factory), - ) - - builtin_instance = MagicMock() - if fallback_result is not None: - builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result - monkeypatch.setattr( - service_module.RecommendAppRetrievalFactory, - "get_buildin_recommend_app_retrieval", - MagicMock(return_value=builtin_instance), - ) - return retrieval_instance, builtin_instance - - -def test_get_recommended_apps_and_categories_should_not_query_trial_table_when_trial_feature_disabled( - monkeypatch: pytest.MonkeyPatch, - mocked_db_session: MagicMock, -) -> None: - # Arrange - expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]} - retrieval_instance, builtin_instance = _mock_factory_for_apps( - monkeypatch, - mode="remote", - result=expected, - ) - monkeypatch.setattr( - service_module.FeatureService, - "get_system_features", - MagicMock(return_value=SimpleNamespace(enable_trial_app=False)), - ) - - # Act - result = RecommendedAppService.get_recommended_apps_and_categories("en-US") - - # Assert - assert result == expected - retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") - builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called() - mocked_db_session.scalar.assert_not_called() - - -def test_get_recommended_apps_and_categories_should_fallback_and_enrich_can_trial_when_trial_feature_enabled( - monkeypatch: pytest.MonkeyPatch, - mocked_db_session: MagicMock, -) -> None: - # Arrange - remote_result = {"recommended_apps": [], "categories": []} - fallback_result = {"recommended_apps": [{"app_id": "app-1"}, {"app_id": "app-2"}], "categories": ["all"]} - _, builtin_instance = _mock_factory_for_apps( - monkeypatch, - mode="remote", - result=remote_result, - fallback_result=fallback_result, - ) - monkeypatch.setattr( - service_module.FeatureService, - "get_system_features", - MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), - ) - mocked_db_session.scalar.side_effect = [SimpleNamespace(id="trial-app"), None] - - # Act - result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP") - - # Assert - builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") - assert result["recommended_apps"][0]["can_trial"] is True - assert result["recommended_apps"][1]["can_trial"] is False - assert mocked_db_session.scalar.call_count == 2 - - -@pytest.mark.parametrize( - ("trial_query_result", "expected_can_trial"), - [ - (SimpleNamespace(id="trial"), True), - (None, False), - ], -) -def test_get_recommend_app_detail_should_set_can_trial_when_trial_feature_enabled( - monkeypatch: pytest.MonkeyPatch, - mocked_db_session: MagicMock, - trial_query_result: Any, - expected_can_trial: bool, -) -> None: - # Arrange - detail = {"id": "app-1", "name": "Test App"} - retrieval_instance = MagicMock() - retrieval_instance.get_recommend_app_detail.return_value = detail - retrieval_factory = MagicMock(return_value=retrieval_instance) - monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False) - monkeypatch.setattr( - service_module.RecommendAppRetrievalFactory, - "get_recommend_app_factory", - MagicMock(return_value=retrieval_factory), - ) - monkeypatch.setattr( - service_module.FeatureService, - "get_system_features", - MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), - ) - mocked_db_session.scalar.return_value = trial_query_result - - # Act - result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail("app-1")) - - # Assert - assert result["id"] == "app-1" - assert result["can_trial"] is expected_can_trial - mocked_db_session.scalar.assert_called_once() - - -def test_add_trial_app_record_should_increment_count_when_existing_record_found( - mocked_db_session: MagicMock, -) -> None: - # Arrange - existing_record = SimpleNamespace(count=3) - mocked_db_session.scalar.return_value = existing_record - - # Act - RecommendedAppService.add_trial_app_record("app-1", "account-1") - - # Assert - assert existing_record.count == 4 - mocked_db_session.scalar.assert_called_once() - mocked_db_session.commit.assert_called_once() - mocked_db_session.add.assert_not_called() - - -def test_add_trial_app_record_should_create_new_record_when_no_existing_record( - mocked_db_session: MagicMock, -) -> None: - # Arrange - mocked_db_session.scalar.return_value = None - - # Act - RecommendedAppService.add_trial_app_record("app-2", "account-2") - - # Assert - mocked_db_session.scalar.assert_called_once() - mocked_db_session.add.assert_called_once() - added = mocked_db_session.add.call_args.args[0] - assert added.app_id == "app-2" - assert added.account_id == "account-2" - assert added.count == 1 - mocked_db_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py index 81a3b181fd..350ff718c1 100644 --- a/api/tests/unit_tests/services/test_trigger_provider_service.py +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -63,6 +63,12 @@ def mock_session(mocker: MockerFixture) -> MagicMock: mock_session_cm.__enter__.return_value = mock_session_instance mock_session_cm.__exit__.return_value = False mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm) + mock_begin_cm = MagicMock() + mock_begin_cm.__enter__.return_value = mock_session_instance + mock_begin_cm.__exit__.return_value = False + mock_sessionmaker_instance = MagicMock() + mock_sessionmaker_instance.begin.return_value = mock_begin_cm + mocker.patch("services.trigger.trigger_provider_service.sessionmaker", return_value=mock_sessionmaker_instance) return mock_session_instance @@ -212,7 +218,6 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap # Assert assert result["result"] == "success" mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type( @@ -406,7 +411,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache( assert subscription.credentials == {"api_key": "new-key"} assert subscription.credential_expires_at == 100 assert subscription.expires_at == 200 - mock_session.commit.assert_called_once() + mock_delete_cache.assert_called_once() @@ -593,7 +598,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials( assert result == {"result": "success", "expires_at": 12345} assert subscription.credentials == {"access_token": "new"} assert subscription.credential_expires_at == 12345 - mock_session.commit.assert_called_once() + cache.delete.assert_called_once() @@ -664,7 +669,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties( assert result == {"result": "success", "expires_at": 999} assert subscription.properties == {"p": "new-enc"} assert subscription.expires_at == 999 - mock_session.commit.assert_called_once() + prop_cache.delete.assert_called_once() @@ -838,7 +843,6 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w assert fake_model.encrypted_oauth_params == "{}" assert fake_model.enabled is True mock_session.add.assert_called_once_with(fake_model) - mock_session.commit.assert_called_once() def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache( @@ -870,7 +874,6 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c assert result == {"result": "success"} assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"} cache.delete.assert_called_once() - mock_session.commit.assert_called_once() def test_get_custom_oauth_client_params_should_return_empty_when_record_missing( @@ -921,7 +924,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit( # Assert assert result == {"result": "success"} - mock_session.commit.assert_called_once() @pytest.mark.parametrize("exists", [True, False]) diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 78049182ad..1b5252fc64 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -617,6 +617,20 @@ class _SessionContext: return False +class _SessionmakerContext: + def __init__(self, session: Any) -> None: + self._session = session + + def begin(self) -> "_SessionmakerContext": + return self + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + @pytest.fixture def flask_app() -> Flask: return Flask(__name__) @@ -625,6 +639,7 @@ def flask_app() -> Flask: def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None: monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock())) monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session)) + monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session)) def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger: @@ -1241,7 +1256,6 @@ def test_sync_webhook_relationships_should_create_missing_records_and_delete_sta # Assert assert len(fake_session.added) == 1 assert len(fake_session.deleted) == 1 - assert fake_session.commit_count == 2 redis_set_mock.assert_called_once() redis_delete_mock.assert_called_once() lock.release.assert_called_once() diff --git a/api/uv.lock b/api/uv.lock index 25c6f0b85e..889015783e 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1137,41 +1137,41 @@ wheels = [ [[package]] name = "cryptography" -version = "46.0.6" +version = "46.0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a4/ba/04b1bd4218cbc58dc90ce967106d51582371b898690f3ae0402876cc4f34/cryptography-46.0.6.tar.gz", hash = "sha256:27550628a518c5c6c903d84f637fbecf287f6cb9ced3804838a1295dc1fd0759", size = 750542, upload-time = "2026-03-25T23:34:53.396Z" } +sdist = { url = "https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/47/23/9285e15e3bc57325b0a72e592921983a701efc1ee8f91c06c5f0235d86d9/cryptography-46.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:64235194bad039a10bb6d2d930ab3323baaec67e2ce36215fd0952fad0930ca8", size = 7176401, upload-time = "2026-03-25T23:33:22.096Z" }, - { url = "https://files.pythonhosted.org/packages/60/f8/e61f8f13950ab6195b31913b42d39f0f9afc7d93f76710f299b5ec286ae6/cryptography-46.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:26031f1e5ca62fcb9d1fcb34b2b60b390d1aacaa15dc8b895a9ed00968b97b30", size = 4275275, upload-time = "2026-03-25T23:33:23.844Z" }, - { url = "https://files.pythonhosted.org/packages/19/69/732a736d12c2631e140be2348b4ad3d226302df63ef64d30dfdb8db7ad1c/cryptography-46.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9a693028b9cbe51b5a1136232ee8f2bc242e4e19d456ded3fa7c86e43c713b4a", size = 4425320, upload-time = "2026-03-25T23:33:25.703Z" }, - { url = "https://files.pythonhosted.org/packages/d4/12/123be7292674abf76b21ac1fc0e1af50661f0e5b8f0ec8285faac18eb99e/cryptography-46.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:67177e8a9f421aa2d3a170c3e56eca4e0128883cf52a071a7cbf53297f18b175", size = 4278082, upload-time = "2026-03-25T23:33:27.423Z" }, - { url = "https://files.pythonhosted.org/packages/5b/ba/d5e27f8d68c24951b0a484924a84c7cdaed7502bac9f18601cd357f8b1d2/cryptography-46.0.6-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:d9528b535a6c4f8ff37847144b8986a9a143585f0540fbcb1a98115b543aa463", size = 4926514, upload-time = "2026-03-25T23:33:29.206Z" }, - { url = "https://files.pythonhosted.org/packages/34/71/1ea5a7352ae516d5512d17babe7e1b87d9db5150b21f794b1377eac1edc0/cryptography-46.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:22259338084d6ae497a19bae5d4c66b7ca1387d3264d1c2c0e72d9e9b6a77b97", size = 4457766, upload-time = "2026-03-25T23:33:30.834Z" }, - { url = "https://files.pythonhosted.org/packages/01/59/562be1e653accee4fdad92c7a2e88fced26b3fdfce144047519bbebc299e/cryptography-46.0.6-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:760997a4b950ff00d418398ad73fbc91aa2894b5c1db7ccb45b4f68b42a63b3c", size = 3986535, upload-time = "2026-03-25T23:33:33.02Z" }, - { url = "https://files.pythonhosted.org/packages/d6/8b/b1ebfeb788bf4624d36e45ed2662b8bd43a05ff62157093c1539c1288a18/cryptography-46.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3dfa6567f2e9e4c5dceb8ccb5a708158a2a871052fa75c8b78cb0977063f1507", size = 4277618, upload-time = "2026-03-25T23:33:34.567Z" }, - { url = "https://files.pythonhosted.org/packages/dd/52/a005f8eabdb28df57c20f84c44d397a755782d6ff6d455f05baa2785bd91/cryptography-46.0.6-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:cdcd3edcbc5d55757e5f5f3d330dd00007ae463a7e7aa5bf132d1f22a4b62b19", size = 4890802, upload-time = "2026-03-25T23:33:37.034Z" }, - { url = "https://files.pythonhosted.org/packages/ec/4d/8e7d7245c79c617d08724e2efa397737715ca0ec830ecb3c91e547302555/cryptography-46.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:d4e4aadb7fc1f88687f47ca20bb7227981b03afaae69287029da08096853b738", size = 4457425, upload-time = "2026-03-25T23:33:38.904Z" }, - { url = "https://files.pythonhosted.org/packages/1d/5c/f6c3596a1430cec6f949085f0e1a970638d76f81c3ea56d93d564d04c340/cryptography-46.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2b417edbe8877cda9022dde3a008e2deb50be9c407eef034aeeb3a8b11d9db3c", size = 4405530, upload-time = "2026-03-25T23:33:40.842Z" }, - { url = "https://files.pythonhosted.org/packages/7e/c9/9f9cea13ee2dbde070424e0c4f621c091a91ffcc504ffea5e74f0e1daeff/cryptography-46.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:380343e0653b1c9d7e1f55b52aaa2dbb2fdf2730088d48c43ca1c7c0abb7cc2f", size = 4667896, upload-time = "2026-03-25T23:33:42.781Z" }, - { url = "https://files.pythonhosted.org/packages/ad/b5/1895bc0821226f129bc74d00eccfc6a5969e2028f8617c09790bf89c185e/cryptography-46.0.6-cp311-abi3-win32.whl", hash = "sha256:bcb87663e1f7b075e48c3be3ecb5f0b46c8fc50b50a97cf264e7f60242dca3f2", size = 3026348, upload-time = "2026-03-25T23:33:45.021Z" }, - { url = "https://files.pythonhosted.org/packages/c3/f8/c9bcbf0d3e6ad288b9d9aa0b1dee04b063d19e8c4f871855a03ab3a297ab/cryptography-46.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:6739d56300662c468fddb0e5e291f9b4d084bead381667b9e654c7dd81705124", size = 3483896, upload-time = "2026-03-25T23:33:46.649Z" }, - { url = "https://files.pythonhosted.org/packages/c4/cc/f330e982852403da79008552de9906804568ae9230da8432f7496ce02b71/cryptography-46.0.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:12cae594e9473bca1a7aceb90536060643128bb274fcea0fc459ab90f7d1ae7a", size = 7162776, upload-time = "2026-03-25T23:34:13.308Z" }, - { url = "https://files.pythonhosted.org/packages/49/b3/dc27efd8dcc4bff583b3f01d4a3943cd8b5821777a58b3a6a5f054d61b79/cryptography-46.0.6-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:639301950939d844a9e1c4464d7e07f902fe9a7f6b215bb0d4f28584729935d8", size = 4270529, upload-time = "2026-03-25T23:34:15.019Z" }, - { url = "https://files.pythonhosted.org/packages/e6/05/e8d0e6eb4f0d83365b3cb0e00eb3c484f7348db0266652ccd84632a3d58d/cryptography-46.0.6-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ed3775295fb91f70b4027aeba878d79b3e55c0b3e97eaa4de71f8f23a9f2eb77", size = 4414827, upload-time = "2026-03-25T23:34:16.604Z" }, - { url = "https://files.pythonhosted.org/packages/2f/97/daba0f5d2dc6d855e2dcb70733c812558a7977a55dd4a6722756628c44d1/cryptography-46.0.6-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8927ccfbe967c7df312ade694f987e7e9e22b2425976ddbf28271d7e58845290", size = 4271265, upload-time = "2026-03-25T23:34:18.586Z" }, - { url = "https://files.pythonhosted.org/packages/89/06/fe1fce39a37ac452e58d04b43b0855261dac320a2ebf8f5260dd55b201a9/cryptography-46.0.6-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:b12c6b1e1651e42ab5de8b1e00dc3b6354fdfd778e7fa60541ddacc27cd21410", size = 4916800, upload-time = "2026-03-25T23:34:20.561Z" }, - { url = "https://files.pythonhosted.org/packages/ff/8a/b14f3101fe9c3592603339eb5d94046c3ce5f7fc76d6512a2d40efd9724e/cryptography-46.0.6-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:063b67749f338ca9c5a0b7fe438a52c25f9526b851e24e6c9310e7195aad3b4d", size = 4448771, upload-time = "2026-03-25T23:34:22.406Z" }, - { url = "https://files.pythonhosted.org/packages/01/b3/0796998056a66d1973fd52ee89dc1bb3b6581960a91ad4ac705f182d398f/cryptography-46.0.6-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:02fad249cb0e090b574e30b276a3da6a149e04ee2f049725b1f69e7b8351ec70", size = 3978333, upload-time = "2026-03-25T23:34:24.281Z" }, - { url = "https://files.pythonhosted.org/packages/c5/3d/db200af5a4ffd08918cd55c08399dc6c9c50b0bc72c00a3246e099d3a849/cryptography-46.0.6-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7e6142674f2a9291463e5e150090b95a8519b2fb6e6aaec8917dd8d094ce750d", size = 4271069, upload-time = "2026-03-25T23:34:25.895Z" }, - { url = "https://files.pythonhosted.org/packages/d7/18/61acfd5b414309d74ee838be321c636fe71815436f53c9f0334bf19064fa/cryptography-46.0.6-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:456b3215172aeefb9284550b162801d62f5f264a081049a3e94307fe20792cfa", size = 4878358, upload-time = "2026-03-25T23:34:27.67Z" }, - { url = "https://files.pythonhosted.org/packages/8b/65/5bf43286d566f8171917cae23ac6add941654ccf085d739195a4eacf1674/cryptography-46.0.6-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:341359d6c9e68834e204ceaf25936dffeafea3829ab80e9503860dcc4f4dac58", size = 4448061, upload-time = "2026-03-25T23:34:29.375Z" }, - { url = "https://files.pythonhosted.org/packages/e0/25/7e49c0fa7205cf3597e525d156a6bce5b5c9de1fd7e8cb01120e459f205a/cryptography-46.0.6-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9a9c42a2723999a710445bc0d974e345c32adfd8d2fac6d8a251fa829ad31cfb", size = 4399103, upload-time = "2026-03-25T23:34:32.036Z" }, - { url = "https://files.pythonhosted.org/packages/44/46/466269e833f1c4718d6cd496ffe20c56c9c8d013486ff66b4f69c302a68d/cryptography-46.0.6-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6617f67b1606dfd9fe4dbfa354a9508d4a6d37afe30306fe6c101b7ce3274b72", size = 4659255, upload-time = "2026-03-25T23:34:33.679Z" }, - { url = "https://files.pythonhosted.org/packages/0a/09/ddc5f630cc32287d2c953fc5d32705e63ec73e37308e5120955316f53827/cryptography-46.0.6-cp38-abi3-win32.whl", hash = "sha256:7f6690b6c55e9c5332c0b59b9c8a3fb232ebf059094c17f9019a51e9827df91c", size = 3010660, upload-time = "2026-03-25T23:34:35.418Z" }, - { url = "https://files.pythonhosted.org/packages/1b/82/ca4893968aeb2709aacfb57a30dec6fa2ab25b10fa9f064b8882ce33f599/cryptography-46.0.6-cp38-abi3-win_amd64.whl", hash = "sha256:79e865c642cfc5c0b3eb12af83c35c5aeff4fa5c672dc28c43721c2c9fdd2f0f", size = 3471160, upload-time = "2026-03-25T23:34:37.191Z" }, + { url = "https://files.pythonhosted.org/packages/0b/5d/4a8f770695d73be252331e60e526291e3df0c9b27556a90a6b47bccca4c2/cryptography-46.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:ea42cbe97209df307fdc3b155f1b6fa2577c0defa8f1f7d3be7d31d189108ad4", size = 7179869, upload-time = "2026-04-08T01:56:17.157Z" }, + { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, + { url = "https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, + { url = "https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, + { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, + { url = "https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, + { url = "https://files.pythonhosted.org/packages/1a/bb/a5c213c19ee94b15dfccc48f363738633a493812687f5567addbcbba9f6f/cryptography-46.0.7-cp311-abi3-win32.whl", hash = "sha256:d23c8ca48e44ee015cd0a54aeccdf9f09004eba9fc96f38c911011d9ff1bd457", size = 3026504, upload-time = "2026-04-08T01:56:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/2b/02/7788f9fefa1d060ca68717c3901ae7fffa21ee087a90b7f23c7a603c32ae/cryptography-46.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:397655da831414d165029da9bc483bed2fe0e75dde6a1523ec2fe63f3c46046b", size = 3488363, upload-time = "2026-04-08T01:56:41.893Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7f/cd42fc3614386bc0c12f0cb3c4ae1fc2bbca5c9662dfed031514911d513d/cryptography-46.0.7-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:462ad5cb1c148a22b2e3bcc5ad52504dff325d17daf5df8d88c17dda1f75f2a4", size = 7165618, upload-time = "2026-04-08T01:57:10.645Z" }, + { url = "https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, + { url = "https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, + { url = "https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, + { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, + { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/f0ab06238e899cc3fb332623f337a7364f36f4bb3f2534c2bb95a35b132c/cryptography-46.0.7-cp38-abi3-win32.whl", hash = "sha256:f247c8c1a1fb45e12586afbb436ef21ff1e80670b2861a90353d9b025583d246", size = 3013001, upload-time = "2026-04-08T01:57:34.933Z" }, + { url = "https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, ] [[package]] diff --git a/dev/start-docker-compose b/dev/start-docker-compose index aa4f66a6cf..1321c3210f 100755 --- a/dev/start-docker-compose +++ b/dev/start-docker-compose @@ -1,8 +1,8 @@ -#!/usr/bin/env bash -set -euo pipefail - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -ROOT="$(dirname "$SCRIPT_DIR")" - -cd "$ROOT/docker" -docker compose --env-file middleware.env -f docker-compose.middleware.yaml -p dify up -d +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +ROOT="$(dirname "$SCRIPT_DIR")" + +cd "$ROOT/docker" +docker compose --env-file middleware.env -f docker-compose.middleware.yaml -p dify up -d diff --git a/docker/.env.example b/docker/.env.example index 4ff37b7e4f..f6acc19e9b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -378,6 +378,20 @@ REDIS_USE_CLUSTERS=false REDIS_CLUSTERS= REDIS_CLUSTERS_PASSWORD= +# Redis connection and retry configuration +# max redis retry +REDIS_RETRY_RETRIES=3 +# Base delay (in seconds) for exponential backoff on retries +REDIS_RETRY_BACKOFF_BASE=1.0 +# Cap (in seconds) for exponential backoff on retries +REDIS_RETRY_BACKOFF_CAP=10.0 +# Timeout (in seconds) for Redis socket operations +REDIS_SOCKET_TIMEOUT=5.0 +# Timeout (in seconds) for establishing a Redis connection +REDIS_SOCKET_CONNECT_TIMEOUT=5.0 +# Interval (in seconds) for Redis health checks +REDIS_HEALTH_CHECK_INTERVAL=30 + # ------------------------------ # Celery Configuration # ------------------------------ @@ -1180,6 +1194,14 @@ MAX_ITERATIONS_NUM=99 # The timeout for the text generation in millisecond TEXT_GENERATION_TIMEOUT_MS=60000 +# Enable the experimental vinext runtime shipped in the image. +EXPERIMENTAL_ENABLE_VINEXT=false + +# Allow inline style attributes in Markdown rendering. +# Enable this if your workflows use Jinja2 templates with styled HTML. +# Only recommended for self-hosted deployments with trusted content. +ALLOW_INLINE_STYLES=false + # Allow rendering unsafe URLs which have "data:" scheme. ALLOW_UNSAFE_DATA_SCHEME=false diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index a1c079ce4a..888f96332c 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -162,9 +162,11 @@ services: NEXT_PUBLIC_SOCKET_URL: ${NEXT_PUBLIC_SOCKET_URL:-ws://localhost} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} + EXPERIMENTAL_ENABLE_VINEXT: ${EXPERIMENTAL_ENABLE_VINEXT:-false} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} CSP_WHITELIST: ${CSP_WHITELIST:-} ALLOW_EMBED: ${ALLOW_EMBED:-false} + ALLOW_INLINE_STYLES: ${ALLOW_INLINE_STYLES:-false} ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index ccf5ef93d9..17a43a9199 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -101,6 +101,12 @@ x-shared-env: &shared-api-worker-env REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} + REDIS_RETRY_RETRIES: ${REDIS_RETRY_RETRIES:-3} + REDIS_RETRY_BACKOFF_BASE: ${REDIS_RETRY_BACKOFF_BASE:-1.0} + REDIS_RETRY_BACKOFF_CAP: ${REDIS_RETRY_BACKOFF_CAP:-10.0} + REDIS_SOCKET_TIMEOUT: ${REDIS_SOCKET_TIMEOUT:-5.0} + REDIS_SOCKET_CONNECT_TIMEOUT: ${REDIS_SOCKET_CONNECT_TIMEOUT:-5.0} + REDIS_HEALTH_CHECK_INTERVAL: ${REDIS_HEALTH_CHECK_INTERVAL:-30} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} CELERY_BACKEND: ${CELERY_BACKEND:-redis} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} @@ -511,6 +517,8 @@ x-shared-env: &shared-api-worker-env MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} + EXPERIMENTAL_ENABLE_VINEXT: ${EXPERIMENTAL_ENABLE_VINEXT:-false} + ALLOW_INLINE_STYLES: ${ALLOW_INLINE_STYLES:-false} ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} @@ -873,9 +881,11 @@ services: NEXT_PUBLIC_SOCKET_URL: ${NEXT_PUBLIC_SOCKET_URL:-ws://localhost} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} + EXPERIMENTAL_ENABLE_VINEXT: ${EXPERIMENTAL_ENABLE_VINEXT:-false} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} CSP_WHITELIST: ${CSP_WHITELIST:-} ALLOW_EMBED: ${ALLOW_EMBED:-false} + ALLOW_INLINE_STYLES: ${ALLOW_INLINE_STYLES:-false} ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} diff --git a/e2e/features/smoke/unauthenticated-entry.feature b/e2e/features/smoke/unauthenticated-entry.feature new file mode 100644 index 0000000000..a2783c1cba --- /dev/null +++ b/e2e/features/smoke/unauthenticated-entry.feature @@ -0,0 +1,7 @@ +@smoke @unauthenticated +Feature: Unauthenticated app console entry + Scenario: Redirect to the sign-in page when opening the apps console without logging in + Given I am not signed in + When I open the apps console + Then I should be redirected to the signin page + And I should see the "Sign in" button diff --git a/e2e/features/step-definitions/common/auth.steps.ts b/e2e/features/step-definitions/common/auth.steps.ts index bf03c2d8f4..bed35244c5 100644 --- a/e2e/features/step-definitions/common/auth.steps.ts +++ b/e2e/features/step-definitions/common/auth.steps.ts @@ -9,3 +9,10 @@ Given('I am signed in as the default E2E admin', async function (this: DifyWorld 'text/plain', ) }) + +Given('I am not signed in', async function (this: DifyWorld) { + this.attach( + 'Using a clean browser context without the shared authenticated storage state.', + 'text/plain', + ) +}) diff --git a/e2e/features/step-definitions/common/navigation.steps.ts b/e2e/features/step-definitions/common/navigation.steps.ts index b18ff035fa..28e6953d65 100644 --- a/e2e/features/step-definitions/common/navigation.steps.ts +++ b/e2e/features/step-definitions/common/navigation.steps.ts @@ -10,6 +10,10 @@ Then('I should stay on the apps console', async function (this: DifyWorld) { await expect(this.getPage()).toHaveURL(/\/apps(?:\?.*)?$/) }) +Then('I should be redirected to the signin page', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/signin(?:\?.*)?$/) +}) + Then('I should see the {string} button', async function (this: DifyWorld, label: string) { await expect(this.getPage().getByRole('button', { name: label })).toBeVisible() }) diff --git a/e2e/features/support/hooks.ts b/e2e/features/support/hooks.ts index a6862d79f5..9e8c025ef8 100644 --- a/e2e/features/support/hooks.ts +++ b/e2e/features/support/hooks.ts @@ -46,7 +46,11 @@ BeforeAll(async () => { Before(async function (this: DifyWorld, { pickle }) { if (!browser) throw new Error('Shared Playwright browser is not available.') - await this.startAuthenticatedSession(browser) + const isUnauthenticatedScenario = pickle.tags.some((tag) => tag.name === '@unauthenticated') + + if (isUnauthenticatedScenario) await this.startUnauthenticatedSession(browser) + else await this.startAuthenticatedSession(browser) + this.scenarioStartedAt = Date.now() const tags = pickle.tags.map((tag) => tag.name).join(' ') diff --git a/e2e/features/support/world.ts b/e2e/features/support/world.ts index 15ab8daf16..bf63199107 100644 --- a/e2e/features/support/world.ts +++ b/e2e/features/support/world.ts @@ -25,12 +25,12 @@ export class DifyWorld extends World { this.pageErrors = [] } - async startAuthenticatedSession(browser: Browser) { + async startSession(browser: Browser, authenticated: boolean) { this.resetScenarioState() this.context = await browser.newContext({ baseURL, locale: defaultLocale, - storageState: authStatePath, + ...(authenticated ? { storageState: authStatePath } : {}), }) this.context.setDefaultTimeout(30_000) this.page = await this.context.newPage() @@ -44,6 +44,14 @@ export class DifyWorld extends World { }) } + async startAuthenticatedSession(browser: Browser) { + await this.startSession(browser, true) + } + + async startUnauthenticatedSession(browser: Browser) { + await this.startSession(browser, false) + } + getPage() { if (!this.page) throw new Error('Playwright page has not been initialized for this scenario.') diff --git a/e2e/package.json b/e2e/package.json index 0ee2afff7f..925418f223 100644 --- a/e2e/package.json +++ b/e2e/package.json @@ -19,6 +19,7 @@ "@types/node": "catalog:", "tsx": "catalog:", "typescript": "catalog:", + "vite": "catalog:", "vite-plus": "catalog:" } } diff --git a/package.json b/package.json index ce3180214b..736a354ef7 100644 --- a/package.json +++ b/package.json @@ -5,6 +5,7 @@ "prepare": "vp config" }, "devDependencies": { + "vite": "catalog:", "vite-plus": "catalog:" }, "engines": { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c44833c251..af376af19b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -249,6 +249,9 @@ catalogs: class-variance-authority: specifier: 0.7.1 version: 0.7.1 + client-only: + specifier: 0.0.1 + version: 0.0.1 clsx: specifier: 2.1.1 version: 2.1.1 @@ -324,9 +327,6 @@ catalogs: fast-deep-equal: specifier: 3.1.3 version: 3.1.3 - foxact: - specifier: 0.3.0 - version: 0.3.0 happy-dom: specifier: 20.8.9 version: 20.8.9 @@ -571,9 +571,12 @@ importers: .: devDependencies: + vite: + specifier: npm:@voidzero-dev/vite-plus-core@0.1.16 + version: '@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)' vite-plus: specifier: 'catalog:' - version: 0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + version: 0.1.16(@types/node@25.5.2)(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3) e2e: devDependencies: @@ -592,9 +595,12 @@ importers: typescript: specifier: 'catalog:' version: 6.0.2 + vite: + specifier: npm:@voidzero-dev/vite-plus-core@0.1.16 + version: '@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)' vite-plus: specifier: 'catalog:' - version: 0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + version: 0.1.16(@types/node@25.5.2)(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3) packages/iconify-collections: devDependencies: @@ -618,19 +624,22 @@ importers: version: 8.58.1(eslint@10.2.0(jiti@2.6.1))(typescript@6.0.2) '@vitest/coverage-v8': specifier: 'catalog:' - version: 4.1.3(@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)) + version: 4.1.3(@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)) eslint: specifier: 'catalog:' version: 10.2.0(jiti@2.6.1) typescript: specifier: 'catalog:' version: 6.0.2 + vite: + specifier: npm:@voidzero-dev/vite-plus-core@0.1.16 + version: '@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)' vite-plus: specifier: 'catalog:' - version: 0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + version: 0.1.16(@types/node@25.5.2)(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3) vitest: specifier: npm:@voidzero-dev/vite-plus-test@0.1.16 - version: '@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)' + version: '@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)' web: dependencies: @@ -730,6 +739,9 @@ importers: class-variance-authority: specifier: 'catalog:' version: 0.7.1 + client-only: + specifier: 'catalog:' + version: 0.0.1 clsx: specifier: 'catalog:' version: 2.1.1 @@ -775,9 +787,6 @@ importers: fast-deep-equal: specifier: 'catalog:' version: 3.1.3 - foxact: - specifier: 'catalog:' - version: 0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) hast-util-to-jsx-runtime: specifier: 'catalog:' version: 2.3.6 @@ -2520,9 +2529,6 @@ packages: '@oxc-project/types@0.121.0': resolution: {integrity: sha512-CGtOARQb9tyv7ECgdAlFxi0Fv7lmzvmlm2rpD/RdijOO9rfk/JvB1CjT8EnoD+tjna/IYgKKw3IV7objRb+aYw==} - '@oxc-project/types@0.122.0': - resolution: {integrity: sha512-oLAl5kBpV4w69UtFZ9xqcmTi+GENWOcPF7FCrczTiBbmC0ibXxCwyvZGbO39rCVEuLGAZM84DH0pUIyyv/YJzA==} - '@oxc-project/types@0.123.0': resolution: {integrity: sha512-YtECP/y8Mj1lSHiUWGSRzy/C6teUKlS87dEfuVKT09LgQbUsBW1rNg+MiJ4buGu3yuADV60gbIvo9/HplA56Ew==} @@ -3290,104 +3296,6 @@ packages: resolution: {integrity: sha512-UuBOt7BOsKVOkFXRe4Ypd/lADuNIfqJXv8GvHqtXaTYXPPKkj2nS2zPllVsrtRjcomDhIJVBnZwfmlI222WH8g==} engines: {node: '>=14.0.0'} - '@rolldown/binding-android-arm64@1.0.0-rc.12': - resolution: {integrity: sha512-pv1y2Fv0JybcykuiiD3qBOBdz6RteYojRFY1d+b95WVuzx211CRh+ytI/+9iVyWQ6koTh5dawe4S/yRfOFjgaA==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [arm64] - os: [android] - - '@rolldown/binding-darwin-arm64@1.0.0-rc.12': - resolution: {integrity: sha512-cFYr6zTG/3PXXF3pUO+umXxt1wkRK/0AYT8lDwuqvRC+LuKYWSAQAQZjCWDQpAH172ZV6ieYrNnFzVVcnSflAg==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [arm64] - os: [darwin] - - '@rolldown/binding-darwin-x64@1.0.0-rc.12': - resolution: {integrity: sha512-ZCsYknnHzeXYps0lGBz8JrF37GpE9bFVefrlmDrAQhOEi4IOIlcoU1+FwHEtyXGx2VkYAvhu7dyBf75EJQffBw==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [x64] - os: [darwin] - - '@rolldown/binding-freebsd-x64@1.0.0-rc.12': - resolution: {integrity: sha512-dMLeprcVsyJsKolRXyoTH3NL6qtsT0Y2xeuEA8WQJquWFXkEC4bcu1rLZZSnZRMtAqwtrF/Ib9Ddtpa/Gkge9Q==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [x64] - os: [freebsd] - - '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.12': - resolution: {integrity: sha512-YqWjAgGC/9M1lz3GR1r1rP79nMgo3mQiiA+Hfo+pvKFK1fAJ1bCi0ZQVh8noOqNacuY1qIcfyVfP6HoyBRZ85Q==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [arm] - os: [linux] - - '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.12': - resolution: {integrity: sha512-/I5AS4cIroLpslsmzXfwbe5OmWvSsrFuEw3mwvbQ1kDxJ822hFHIx+vsN/TAzNVyepI/j/GSzrtCIwQPeKCLIg==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [arm64] - os: [linux] - libc: [glibc] - - '@rolldown/binding-linux-arm64-musl@1.0.0-rc.12': - resolution: {integrity: sha512-V6/wZztnBqlx5hJQqNWwFdxIKN0m38p8Jas+VoSfgH54HSj9tKTt1dZvG6JRHcjh6D7TvrJPWFGaY9UBVOaWPw==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [arm64] - os: [linux] - libc: [musl] - - '@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.12': - resolution: {integrity: sha512-AP3E9BpcUYliZCxa3w5Kwj9OtEVDYK6sVoUzy4vTOJsjPOgdaJZKFmN4oOlX0Wp0RPV2ETfmIra9x1xuayFB7g==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [ppc64] - os: [linux] - libc: [glibc] - - '@rolldown/binding-linux-s390x-gnu@1.0.0-rc.12': - resolution: {integrity: sha512-nWwpvUSPkoFmZo0kQazZYOrT7J5DGOJ/+QHHzjvNlooDZED8oH82Yg67HvehPPLAg5fUff7TfWFHQS8IV1n3og==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [s390x] - os: [linux] - libc: [glibc] - - '@rolldown/binding-linux-x64-gnu@1.0.0-rc.12': - resolution: {integrity: sha512-RNrafz5bcwRy+O9e6P8Z/OCAJW/A+qtBczIqVYwTs14pf4iV1/+eKEjdOUta93q2TsT/FI0XYDP3TCky38LMAg==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [x64] - os: [linux] - libc: [glibc] - - '@rolldown/binding-linux-x64-musl@1.0.0-rc.12': - resolution: {integrity: sha512-Jpw/0iwoKWx3LJ2rc1yjFrj+T7iHZn2JDg1Yny1ma0luviFS4mhAIcd1LFNxK3EYu3DHWCps0ydXQ5i/rrJ2ig==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [x64] - os: [linux] - libc: [musl] - - '@rolldown/binding-openharmony-arm64@1.0.0-rc.12': - resolution: {integrity: sha512-vRugONE4yMfVn0+7lUKdKvN4D5YusEiPilaoO2sgUWpCvrncvWgPMzK00ZFFJuiPgLwgFNP5eSiUlv2tfc+lpA==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [arm64] - os: [openharmony] - - '@rolldown/binding-wasm32-wasi@1.0.0-rc.12': - resolution: {integrity: sha512-ykGiLr/6kkiHc0XnBfmFJuCjr5ZYKKofkx+chJWDjitX+KsJuAmrzWhwyOMSHzPhzOHOy7u9HlFoa5MoAOJ/Zg==} - engines: {node: '>=14.0.0'} - cpu: [wasm32] - - '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.12': - resolution: {integrity: sha512-5eOND4duWkwx1AzCxadcOrNeighiLwMInEADT0YM7xeEOOFcovWZCq8dadXgcRHSf3Ulh1kFo/qvzoFiCLOL1Q==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [arm64] - os: [win32] - - '@rolldown/binding-win32-x64-msvc@1.0.0-rc.12': - resolution: {integrity: sha512-PyqoipaswDLAZtot351MLhrlrh6lcZPo2LSYE+VDxbVk24LVKAGOuE4hb8xZQmrPAuEtTZW8E6D2zc5EUZX4Lw==} - engines: {node: ^20.19.0 || >=22.12.0} - cpu: [x64] - os: [win32] - - '@rolldown/pluginutils@1.0.0-rc.12': - resolution: {integrity: sha512-HHMwmarRKvoFsJorqYlFeFRzXZqCt2ETQlEDOb9aqssrnVBB1/+xgTGtuTrIk5vzLNX1MjMtTf7W9z3tsSbrxw==} - '@rolldown/pluginutils@1.0.0-rc.13': resolution: {integrity: sha512-3ngTAv6F/Py35BsYbeeLeecvhMKdsKm4AoOETVhAA+Qc8nrA2I0kF7oa93mE9qnIurngOSpMnQ0x2nQY2FPviA==} @@ -5979,9 +5887,6 @@ packages: resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==} engines: {node: '>=0.10.0'} - event-target-bus@1.0.0: - resolution: {integrity: sha512-uPcWKbj/BJU3Tbw9XqhHqET4/LBOhvv3/SJWr7NksxA6TC5YqBpaZgawE9R+WpYFCBFSAE4Vun+xQS6w4ABdlA==} - events@3.3.0: resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==} engines: {node: '>=0.8.x'} @@ -6094,17 +5999,6 @@ packages: engines: {node: '>=18.3.0'} hasBin: true - foxact@0.3.0: - resolution: {integrity: sha512-CSlMlC0KlKQQEO83iLeQCLuT1V0OqnMWj7mjLstIDV8baMe1w4F7z3cz3/T+6Z8W12jqkQj07rwlw4Gi39knGg==} - peerDependencies: - react: '*' - react-dom: '*' - peerDependenciesMeta: - react: - optional: true - react-dom: - optional: true - fs-constants@1.0.0: resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==} @@ -7744,11 +7638,6 @@ packages: robust-predicates@3.0.3: resolution: {integrity: sha512-NS3levdsRIUOmiJ8FZWCP7LG3QpJyrs/TE0Zpf1yvZu8cAJJ6QMW92H1c7kWpdIHo8RvmLxN/o2JXTKHp74lUA==} - rolldown@1.0.0-rc.12: - resolution: {integrity: sha512-yP4USLIMYrwpPHEFB5JGH1uxhcslv6/hL0OyvTuY+3qlOSJvZ7ntYnoWpehBxufkgN0cvXxppuTu5hHa/zPh+A==} - engines: {node: ^20.19.0 || >=22.12.0} - hasBin: true - rollup@4.59.0: resolution: {integrity: sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==} engines: {node: '>=18.0.0', npm: '>=8.0.0'} @@ -7823,9 +7712,6 @@ packages: resolution: {integrity: sha512-OwrZRZAfhHww0WEnKHDY8OM0U/Qs8OTfIDWhUD4BLpNJUfXK4cGmjiagGze086m+mhI+V2nD0gfbHEnJjb9STA==} engines: {node: '>=10'} - server-only@0.0.1: - resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==} - sharp@0.34.5: resolution: {integrity: sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} @@ -8494,49 +8380,6 @@ packages: peerDependencies: vite: '*' - vite@8.0.3: - resolution: {integrity: sha512-B9ifbFudT1TFhfltfaIPgjo9Z3mDynBTJSUYxTjOQruf/zHH+ezCQKcoqO+h7a9Pw9Nm/OtlXAiGT1axBgwqrQ==} - engines: {node: ^20.19.0 || >=22.12.0} - hasBin: true - peerDependencies: - '@types/node': ^20.19.0 || >=22.12.0 - '@vitejs/devtools': ^0.1.0 - esbuild: 0.27.2 - jiti: '>=1.21.0' - less: ^4.0.0 - sass: ^1.70.0 - sass-embedded: ^1.70.0 - stylus: '>=0.54.8' - sugarss: ^5.0.0 - terser: ^5.16.0 - tsx: ^4.8.1 - yaml: 2.8.3 - peerDependenciesMeta: - '@types/node': - optional: true - '@vitejs/devtools': - optional: true - esbuild: - optional: true - jiti: - optional: true - less: - optional: true - sass: - optional: true - sass-embedded: - optional: true - stylus: - optional: true - sugarss: - optional: true - terser: - optional: true - tsx: - optional: true - yaml: - optional: true - vitefu@1.1.3: resolution: {integrity: sha512-ub4okH7Z5KLjb6hDyjqrGXqWtWvoYdU3IGm/NorpgHncKoLTCfRIbvlhBm7r0YstIaQRYlp4yEbFqDcKSzXSSg==} peerDependencies: @@ -10309,8 +10152,6 @@ snapshots: '@oxc-project/types@0.121.0': {} - '@oxc-project/types@0.122.0': {} - '@oxc-project/types@0.123.0': {} '@oxc-resolver/binding-android-arm-eabi@11.19.1': @@ -10875,58 +10716,6 @@ snapshots: '@rgrove/parse-xml@4.2.0': {} - '@rolldown/binding-android-arm64@1.0.0-rc.12': - optional: true - - '@rolldown/binding-darwin-arm64@1.0.0-rc.12': - optional: true - - '@rolldown/binding-darwin-x64@1.0.0-rc.12': - optional: true - - '@rolldown/binding-freebsd-x64@1.0.0-rc.12': - optional: true - - '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.12': - optional: true - - '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.12': - optional: true - - '@rolldown/binding-linux-arm64-musl@1.0.0-rc.12': - optional: true - - '@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.12': - optional: true - - '@rolldown/binding-linux-s390x-gnu@1.0.0-rc.12': - optional: true - - '@rolldown/binding-linux-x64-gnu@1.0.0-rc.12': - optional: true - - '@rolldown/binding-linux-x64-musl@1.0.0-rc.12': - optional: true - - '@rolldown/binding-openharmony-arm64@1.0.0-rc.12': - optional: true - - '@rolldown/binding-wasm32-wasi@1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)': - dependencies: - '@napi-rs/wasm-runtime': 1.1.2(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) - transitivePeerDependencies: - - '@emnapi/core' - - '@emnapi/runtime' - optional: true - - '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.12': - optional: true - - '@rolldown/binding-win32-x64-msvc@1.0.0-rc.12': - optional: true - - '@rolldown/pluginutils@1.0.0-rc.12': {} - '@rolldown/pluginutils@1.0.0-rc.13': {} '@rolldown/pluginutils@1.0.0-rc.7': {} @@ -12151,20 +11940,6 @@ snapshots: tinyrainbow: 3.1.0 vitest: '@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)' - '@vitest/coverage-v8@4.1.3(@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3))': - dependencies: - '@bcoe/v8-coverage': 1.0.2 - '@vitest/utils': 4.1.3 - ast-v8-to-istanbul: 1.0.0 - istanbul-lib-coverage: 3.2.2 - istanbul-lib-report: 3.0.1 - istanbul-reports: 3.2.0 - magicast: 0.5.2 - obug: 2.1.1 - std-env: 4.0.0 - tinyrainbow: 3.1.0 - vitest: '@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)' - '@vitest/eslint-plugin@1.6.14(@typescript-eslint/eslint-plugin@8.58.1(@typescript-eslint/parser@8.58.1(eslint@10.2.0(jiti@2.6.1))(typescript@6.0.2))(eslint@10.2.0(jiti@2.6.1))(typescript@6.0.2))(@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(eslint@10.2.0(jiti@2.6.1))(typescript@6.0.2)': dependencies: '@typescript-eslint/scope-manager': 8.58.1 @@ -12283,46 +12058,6 @@ snapshots: - utf-8-validate - yaml - '@voidzero-dev/vite-plus-test@0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)': - dependencies: - '@standard-schema/spec': 1.1.0 - '@types/chai': 5.2.3 - '@voidzero-dev/vite-plus-core': 0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3) - es-module-lexer: 1.7.0 - obug: 2.1.1 - pixelmatch: 7.1.0 - pngjs: 7.0.0 - sirv: 3.0.2 - std-env: 4.0.0 - tinybench: 2.9.0 - tinyexec: 1.0.4 - tinyglobby: 0.2.15 - vite: 8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3) - ws: 8.20.0 - optionalDependencies: - '@types/node': 25.5.2 - happy-dom: 20.8.9 - transitivePeerDependencies: - - '@arethetypeswrong/core' - - '@tsdown/css' - - '@tsdown/exe' - - '@vitejs/devtools' - - bufferutil - - esbuild - - jiti - - less - - publint - - sass - - sass-embedded - - stylus - - sugarss - - terser - - tsx - - typescript - - unplugin-unused - - utf-8-validate - - yaml - '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.16': optional: true @@ -13856,8 +13591,6 @@ snapshots: esutils@2.0.3: {} - event-target-bus@1.0.0: {} - events@3.3.0: {} expand-template@2.0.3: @@ -13965,15 +13698,6 @@ snapshots: dependencies: fd-package-json: 2.0.0 - foxact@0.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4): - dependencies: - client-only: 0.0.1 - event-target-bus: 1.0.0 - server-only: 0.0.1 - optionalDependencies: - react: 19.2.4 - react-dom: 19.2.4(react@19.2.4) - fs-constants@1.0.0: optional: true @@ -16104,30 +15828,6 @@ snapshots: robust-predicates@3.0.3: {} - rolldown@1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1): - dependencies: - '@oxc-project/types': 0.122.0 - '@rolldown/pluginutils': 1.0.0-rc.12 - optionalDependencies: - '@rolldown/binding-android-arm64': 1.0.0-rc.12 - '@rolldown/binding-darwin-arm64': 1.0.0-rc.12 - '@rolldown/binding-darwin-x64': 1.0.0-rc.12 - '@rolldown/binding-freebsd-x64': 1.0.0-rc.12 - '@rolldown/binding-linux-arm-gnueabihf': 1.0.0-rc.12 - '@rolldown/binding-linux-arm64-gnu': 1.0.0-rc.12 - '@rolldown/binding-linux-arm64-musl': 1.0.0-rc.12 - '@rolldown/binding-linux-ppc64-gnu': 1.0.0-rc.12 - '@rolldown/binding-linux-s390x-gnu': 1.0.0-rc.12 - '@rolldown/binding-linux-x64-gnu': 1.0.0-rc.12 - '@rolldown/binding-linux-x64-musl': 1.0.0-rc.12 - '@rolldown/binding-openharmony-arm64': 1.0.0-rc.12 - '@rolldown/binding-wasm32-wasi': 1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) - '@rolldown/binding-win32-arm64-msvc': 1.0.0-rc.12 - '@rolldown/binding-win32-x64-msvc': 1.0.0-rc.12 - transitivePeerDependencies: - - '@emnapi/core' - - '@emnapi/runtime' - rollup@4.59.0: dependencies: '@types/estree': 1.0.8 @@ -16233,8 +15933,6 @@ snapshots: seroval@1.5.1: {} - server-only@0.0.1: {} - sharp@0.34.5: dependencies: '@img/colour': 1.1.0 @@ -16967,51 +16665,6 @@ snapshots: - vite - yaml - vite-plus@0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3): - dependencies: - '@oxc-project/types': 0.123.0 - '@voidzero-dev/vite-plus-core': 0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3) - '@voidzero-dev/vite-plus-test': 0.1.16(@types/node@25.5.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) - oxfmt: 0.43.0 - oxlint: 1.58.0(oxlint-tsgolint@0.20.0) - oxlint-tsgolint: 0.20.0 - optionalDependencies: - '@voidzero-dev/vite-plus-darwin-arm64': 0.1.16 - '@voidzero-dev/vite-plus-darwin-x64': 0.1.16 - '@voidzero-dev/vite-plus-linux-arm64-gnu': 0.1.16 - '@voidzero-dev/vite-plus-linux-arm64-musl': 0.1.16 - '@voidzero-dev/vite-plus-linux-x64-gnu': 0.1.16 - '@voidzero-dev/vite-plus-linux-x64-musl': 0.1.16 - '@voidzero-dev/vite-plus-win32-arm64-msvc': 0.1.16 - '@voidzero-dev/vite-plus-win32-x64-msvc': 0.1.16 - transitivePeerDependencies: - - '@arethetypeswrong/core' - - '@edge-runtime/vm' - - '@opentelemetry/api' - - '@tsdown/css' - - '@tsdown/exe' - - '@types/node' - - '@vitejs/devtools' - - '@vitest/ui' - - bufferutil - - esbuild - - happy-dom - - jiti - - jsdom - - less - - publint - - sass - - sass-embedded - - stylus - - sugarss - - terser - - tsx - - typescript - - unplugin-unused - - utf-8-validate - - vite - - yaml - vite-tsconfig-paths@5.1.4(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(typescript@6.0.2): dependencies: debug: 4.4.3(supports-color@8.1.1) @@ -17033,25 +16686,6 @@ snapshots: - supports-color - typescript - vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3): - dependencies: - lightningcss: 1.32.0 - picomatch: 4.0.4 - postcss: 8.5.9 - rolldown: 1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) - tinyglobby: 0.2.15 - optionalDependencies: - '@types/node': 25.5.2 - fsevents: 2.3.3 - jiti: 2.6.1 - sass: 1.98.0 - terser: 5.46.1 - tsx: 4.21.0 - yaml: 2.8.3 - transitivePeerDependencies: - - '@emnapi/core' - - '@emnapi/runtime' - vitefu@1.1.3(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)): optionalDependencies: vite: '@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)' diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 104a7e30a3..d26eee4171 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -129,6 +129,7 @@ catalog: ahooks: 3.9.7 autoprefixer: 10.4.27 class-variance-authority: 0.7.1 + client-only: 0.0.1 clsx: 2.1.1 cmdk: 1.1.1 code-inspector-plugin: 1.5.1 @@ -154,7 +155,6 @@ catalog: eslint-plugin-sonarjs: 4.0.2 eslint-plugin-storybook: 10.3.5 fast-deep-equal: 3.1.3 - foxact: 0.3.0 happy-dom: 20.8.9 hast-util-to-jsx-runtime: 2.3.6 hono: 4.12.12 diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index da9f7353ac..e058edb0ca 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -62,6 +62,7 @@ "@vitest/coverage-v8": "catalog:", "eslint": "catalog:", "typescript": "catalog:", + "vite": "catalog:", "vite-plus": "catalog:", "vitest": "catalog:" } diff --git a/web/.env.example b/web/.env.example index 1c726d8dca..643aba482e 100644 --- a/web/.env.example +++ b/web/.env.example @@ -50,6 +50,9 @@ NEXT_PUBLIC_CSP_WHITELIST= # Default is not allow to embed into iframe to prevent Clickjacking: https://owasp.org/www-community/attacks/Clickjacking NEXT_PUBLIC_ALLOW_EMBED= +# Allow inline style attributes in Markdown rendering (self-hosted opt-in). +NEXT_PUBLIC_ALLOW_INLINE_STYLES=false + # Allow rendering unsafe URLs which have "data:" scheme. NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME=false diff --git a/web/Dockerfile b/web/Dockerfile index 030651bf27..4971f86f97 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -42,7 +42,7 @@ COPY . . WORKDIR /app/web ENV NODE_OPTIONS="--max-old-space-size=4096" -RUN pnpm build +RUN pnpm build && pnpm build:vinext # production stage @@ -56,6 +56,7 @@ ENV APP_API_URL=http://127.0.0.1:5001 ENV MARKETPLACE_API_URL=https://marketplace.dify.ai ENV MARKETPLACE_URL=https://marketplace.dify.ai ENV PORT=3000 +ENV EXPERIMENTAL_ENABLE_VINEXT=false ENV NEXT_TELEMETRY_DISABLED=1 # set timezone @@ -73,9 +74,10 @@ RUN addgroup -S -g ${dify_uid} dify && \ WORKDIR /app -COPY --from=builder --chown=dify:dify /app/web/public ./web/public -COPY --from=builder --chown=dify:dify /app/web/.next/standalone ./ -COPY --from=builder --chown=dify:dify /app/web/.next/static ./web/.next/static +COPY --from=builder --chown=dify:dify /app/web/public ./targets/next/web/public +COPY --from=builder --chown=dify:dify /app/web/.next/standalone ./targets/next/ +COPY --from=builder --chown=dify:dify /app/web/.next/static ./targets/next/web/.next/static +COPY --from=builder --chown=dify:dify /app/web/dist/standalone ./targets/vinext COPY --chown=dify:dify --chmod=755 web/docker/entrypoint.sh ./entrypoint.sh diff --git a/web/app/components/app/annotation/view-annotation-modal/__tests__/hit-history-no-data.spec.tsx b/web/app/components/app/annotation/view-annotation-modal/__tests__/hit-history-no-data.spec.tsx new file mode 100644 index 0000000000..33a38106d0 --- /dev/null +++ b/web/app/components/app/annotation/view-annotation-modal/__tests__/hit-history-no-data.spec.tsx @@ -0,0 +1,10 @@ +import { render, screen } from '@testing-library/react' +import HitHistoryNoData from '../hit-history-no-data' + +describe('HitHistoryNoData', () => { + it('should render the empty history message', () => { + render() + + expect(screen.getByText('appAnnotation.viewModal.noHitHistory')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/app-access-control/__tests__/access-control-dialog.spec.tsx b/web/app/components/app/app-access-control/__tests__/access-control-dialog.spec.tsx new file mode 100644 index 0000000000..5c7d2f2dc0 --- /dev/null +++ b/web/app/components/app/app-access-control/__tests__/access-control-dialog.spec.tsx @@ -0,0 +1,32 @@ +/* eslint-disable ts/no-explicit-any */ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import AccessControlDialog from '../access-control-dialog' + +describe('AccessControlDialog', () => { + it('should render dialog content when visible', () => { + render( + +
Dialog Content
+
, + ) + + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByText('Dialog Content')).toBeInTheDocument() + }) + + it('should trigger onClose when clicking the close control', async () => { + const onClose = vi.fn() + render( + +
Dialog Content
+
, + ) + + const closeButton = document.body.querySelector('div.absolute.right-5.top-5') as HTMLElement + fireEvent.click(closeButton) + + await waitFor(() => { + expect(onClose).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/app-access-control/__tests__/access-control-item.spec.tsx b/web/app/components/app/app-access-control/__tests__/access-control-item.spec.tsx new file mode 100644 index 0000000000..b1a862a13c --- /dev/null +++ b/web/app/components/app/app-access-control/__tests__/access-control-item.spec.tsx @@ -0,0 +1,45 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import useAccessControlStore from '@/context/access-control-store' +import { AccessMode } from '@/models/access-control' +import AccessControlItem from '../access-control-item' + +describe('AccessControlItem', () => { + beforeEach(() => { + vi.clearAllMocks() + useAccessControlStore.setState({ + appId: '', + specificGroups: [], + specificMembers: [], + currentMenu: AccessMode.PUBLIC, + selectedGroupsForBreadcrumb: [], + }) + }) + + it('should update current menu when selecting a different access type', () => { + render( + + Organization Only + , + ) + + const option = screen.getByText('Organization Only').parentElement as HTMLElement + fireEvent.click(option) + + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION) + }) + + it('should keep the selected state for the active access type', () => { + useAccessControlStore.setState({ + currentMenu: AccessMode.ORGANIZATION, + }) + + render( + + Organization Only + , + ) + + const option = screen.getByText('Organization Only').parentElement as HTMLElement + expect(option).toHaveClass('border-components-option-card-option-selected-border') + }) +}) diff --git a/web/app/components/app/app-access-control/__tests__/add-member-or-group-pop.spec.tsx b/web/app/components/app/app-access-control/__tests__/add-member-or-group-pop.spec.tsx new file mode 100644 index 0000000000..725b121d30 --- /dev/null +++ b/web/app/components/app/app-access-control/__tests__/add-member-or-group-pop.spec.tsx @@ -0,0 +1,130 @@ +import type { AccessControlAccount, AccessControlGroup, Subject } from '@/models/access-control' +import { fireEvent, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import useAccessControlStore from '@/context/access-control-store' +import { SubjectType } from '@/models/access-control' +import AddMemberOrGroupDialog from '../add-member-or-group-pop' + +const mockUseSearchForWhiteListCandidates = vi.fn() +const intersectionObserverMocks = vi.hoisted(() => ({ + callback: null as null | ((entries: Array<{ isIntersecting: boolean }>) => void), +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (value: { userProfile: { email: string } }) => T) => selector({ + userProfile: { + email: 'member@example.com', + }, + }), +})) + +vi.mock('@/service/access-control', () => ({ + useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), +})) + +const createGroup = (overrides: Partial = {}): AccessControlGroup => ({ + id: 'group-1', + name: 'Group One', + groupSize: 5, + ...overrides, +} as AccessControlGroup) + +const createMember = (overrides: Partial = {}): AccessControlAccount => ({ + id: 'member-1', + name: 'Member One', + email: 'member@example.com', + avatar: '', + avatarUrl: '', + ...overrides, +} as AccessControlAccount) + +describe('AddMemberOrGroupDialog', () => { + const baseGroup = createGroup() + const baseMember = createMember() + const groupSubject: Subject = { + subjectId: baseGroup.id, + subjectType: SubjectType.GROUP, + groupData: baseGroup, + } as Subject + const memberSubject: Subject = { + subjectId: baseMember.id, + subjectType: SubjectType.ACCOUNT, + accountData: baseMember, + } as Subject + + beforeAll(() => { + class MockIntersectionObserver { + constructor(callback: (entries: Array<{ isIntersecting: boolean }>) => void) { + intersectionObserverMocks.callback = callback + } + + observe = vi.fn(() => undefined) + disconnect = vi.fn(() => undefined) + unobserve = vi.fn(() => undefined) + } + + // @ts-expect-error test DOM typings do not guarantee IntersectionObserver here + globalThis.IntersectionObserver = MockIntersectionObserver + }) + + beforeEach(() => { + vi.clearAllMocks() + useAccessControlStore.setState({ + appId: 'app-1', + specificGroups: [], + specificMembers: [], + currentMenu: SubjectType.GROUP as never, + selectedGroupsForBreadcrumb: [], + }) + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + data: { + pages: [{ currPage: 1, subjects: [groupSubject, memberSubject], hasMore: false }], + }, + }) + }) + + it('should open the search popover and display candidates', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByPlaceholderText('app.accessControlDialog.operateGroupAndMember.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + it('should allow expanding groups and selecting members', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + await user.click(screen.getByText('app.accessControlDialog.operateGroupAndMember.expand')) + + expect(useAccessControlStore.getState().selectedGroupsForBreadcrumb).toEqual([baseGroup]) + + const memberCheckbox = screen.getByText(baseMember.name).parentElement?.previousElementSibling as HTMLElement + fireEvent.click(memberCheckbox) + + expect(useAccessControlStore.getState().specificMembers).toEqual([baseMember]) + }) + + it('should show the empty state when no candidates are returned', async () => { + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + data: { pages: [] }, + }) + + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByText('app.accessControlDialog.operateGroupAndMember.noResult')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/app-access-control/__tests__/index.spec.tsx b/web/app/components/app/app-access-control/__tests__/index.spec.tsx new file mode 100644 index 0000000000..f2fa09f98a --- /dev/null +++ b/web/app/components/app/app-access-control/__tests__/index.spec.tsx @@ -0,0 +1,121 @@ +/* eslint-disable ts/no-explicit-any */ +import type { App } from '@/types/app' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { toast } from '@/app/components/base/ui/toast' +import useAccessControlStore from '@/context/access-control-store' +import { AccessMode } from '@/models/access-control' +import AccessControl from '../index' + +const mockMutateAsync = vi.fn() +const mockUseUpdateAccessMode = vi.fn(() => ({ + isPending: false, + mutateAsync: mockMutateAsync, +})) +const mockUseAppWhiteListSubjects = vi.fn() +const mockUseSearchForWhiteListCandidates = vi.fn() +let mockWebappAuth = { + enabled: true, + allow_sso: true, + allow_email_password_login: false, + allow_email_code_login: false, +} + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: { webapp_auth: typeof mockWebappAuth } }) => unknown) => selector({ + systemFeatures: { + webapp_auth: mockWebappAuth, + }, + }), +})) + +vi.mock('@/service/access-control', () => ({ + useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), + useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), + useUpdateAccessMode: () => mockUseUpdateAccessMode(), +})) + +describe('AccessControl', () => { + beforeEach(() => { + vi.clearAllMocks() + mockWebappAuth = { + enabled: true, + allow_sso: true, + allow_email_password_login: false, + allow_email_code_login: false, + } + useAccessControlStore.setState({ + appId: '', + specificGroups: [], + specificMembers: [], + currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS, + selectedGroupsForBreadcrumb: [], + }) + mockMutateAsync.mockResolvedValue(undefined) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: false, + data: { + groups: [], + members: [], + }, + }) + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + data: { pages: [] }, + }) + }) + + it('should initialize menu from the app and update access mode on confirm', async () => { + const onClose = vi.fn() + const onConfirm = vi.fn() + const toastSpy = vi.spyOn(toast, 'success').mockReturnValue('toast-success') + const app = { + id: 'app-id-1', + access_mode: AccessMode.PUBLIC, + } as App + + render( + , + ) + + await waitFor(() => { + expect(useAccessControlStore.getState().appId).toBe(app.id) + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.PUBLIC) + }) + + fireEvent.click(screen.getByText('common.operation.confirm')) + + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalledWith({ + appId: app.id, + accessMode: AccessMode.PUBLIC, + }) + expect(toastSpy).toHaveBeenCalledWith('app.accessControlDialog.updateSuccess') + expect(onConfirm).toHaveBeenCalledTimes(1) + }) + }) + + it('should show the external-members option when SSO tip is visible', () => { + mockWebappAuth = { + enabled: false, + allow_sso: false, + allow_email_password_login: false, + allow_email_code_login: false, + } + + render( + , + ) + + expect(screen.getByText('app.accessControlDialog.accessItems.external')).toBeInTheDocument() + expect(screen.getByText('app.accessControlDialog.accessItems.anyone')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/app-access-control/__tests__/specific-groups-or-members.spec.tsx b/web/app/components/app/app-access-control/__tests__/specific-groups-or-members.spec.tsx new file mode 100644 index 0000000000..7b198c4e66 --- /dev/null +++ b/web/app/components/app/app-access-control/__tests__/specific-groups-or-members.spec.tsx @@ -0,0 +1,97 @@ +import type { AccessControlAccount, AccessControlGroup } from '@/models/access-control' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import useAccessControlStore from '@/context/access-control-store' +import { AccessMode } from '@/models/access-control' +import SpecificGroupsOrMembers from '../specific-groups-or-members' + +const mockUseAppWhiteListSubjects = vi.fn() + +vi.mock('@/service/access-control', () => ({ + useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), +})) + +vi.mock('../add-member-or-group-pop', () => ({ + default: () =>
, +})) + +const createGroup = (overrides: Partial = {}): AccessControlGroup => ({ + id: 'group-1', + name: 'Group One', + groupSize: 5, + ...overrides, +} as AccessControlGroup) + +const createMember = (overrides: Partial = {}): AccessControlAccount => ({ + id: 'member-1', + name: 'Member One', + email: 'member@example.com', + avatar: '', + avatarUrl: '', + ...overrides, +} as AccessControlAccount) + +describe('SpecificGroupsOrMembers', () => { + const baseGroup = createGroup() + const baseMember = createMember() + + beforeEach(() => { + vi.clearAllMocks() + useAccessControlStore.setState({ + appId: '', + specificGroups: [], + specificMembers: [], + currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS, + selectedGroupsForBreadcrumb: [], + }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: false, + data: { + groups: [baseGroup], + members: [baseMember], + }, + }) + }) + + it('should render the collapsed row when not in specific mode', () => { + useAccessControlStore.setState({ + currentMenu: AccessMode.ORGANIZATION, + }) + + render() + + expect(screen.getByText('app.accessControlDialog.accessItems.specific')).toBeInTheDocument() + expect(screen.queryByTestId('add-member-or-group-dialog')).not.toBeInTheDocument() + }) + + it('should show loading while whitelist subjects are pending', async () => { + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: true, + data: undefined, + }) + + const { container } = render() + + await waitFor(() => { + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + }) + + it('should render fetched groups and members and support removal', async () => { + useAccessControlStore.setState({ appId: 'app-1' }) + + render() + + await waitFor(() => { + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + const groupRemove = screen.getByText(baseGroup.name).closest('div')?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(groupRemove) + expect(useAccessControlStore.getState().specificGroups).toEqual([]) + + const memberRemove = screen.getByText(baseMember.name).closest('div')?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(memberRemove) + expect(useAccessControlStore.getState().specificMembers).toEqual([]) + }) +}) diff --git a/web/app/components/app/configuration/config-var/__tests__/input-type-icon.spec.tsx b/web/app/components/app/configuration/config-var/__tests__/input-type-icon.spec.tsx new file mode 100644 index 0000000000..0b492a06ed --- /dev/null +++ b/web/app/components/app/configuration/config-var/__tests__/input-type-icon.spec.tsx @@ -0,0 +1,26 @@ +import { render, screen } from '@testing-library/react' +import { InputVarType } from '@/app/components/workflow/types' +import InputTypeIcon from '../input-type-icon' + +const mockInputVarTypeIcon = vi.fn(({ type, className }: { type: InputVarType, className?: string }) => ( +
+)) + +vi.mock('@/app/components/workflow/nodes/_base/components/input-var-type-icon', () => ({ + default: (props: { type: InputVarType, className?: string }) => mockInputVarTypeIcon(props), +})) + +describe('InputTypeIcon', () => { + it('should map string variables to the workflow text-input icon', () => { + render() + + expect(screen.getByTestId('input-var-type-icon')).toHaveAttribute('data-type', InputVarType.textInput) + expect(screen.getByTestId('input-var-type-icon')).toHaveClass('marker') + }) + + it('should map select variables to the workflow select icon', () => { + render() + + expect(screen.getByTestId('input-var-type-icon')).toHaveAttribute('data-type', InputVarType.select) + }) +}) diff --git a/web/app/components/app/configuration/config-var/__tests__/modal-foot.spec.tsx b/web/app/components/app/configuration/config-var/__tests__/modal-foot.spec.tsx new file mode 100644 index 0000000000..e84189ddff --- /dev/null +++ b/web/app/components/app/configuration/config-var/__tests__/modal-foot.spec.tsx @@ -0,0 +1,19 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import ModalFoot from '../modal-foot' + +describe('ModalFoot', () => { + it('should trigger cancel and confirm callbacks', () => { + const onCancel = vi.fn() + const onConfirm = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + + expect(onCancel).toHaveBeenCalledTimes(1) + expect(onConfirm).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/app/configuration/config-var/__tests__/select-var-type.spec.tsx b/web/app/components/app/configuration/config-var/__tests__/select-var-type.spec.tsx new file mode 100644 index 0000000000..611aaa1c8a --- /dev/null +++ b/web/app/components/app/configuration/config-var/__tests__/select-var-type.spec.tsx @@ -0,0 +1,16 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import SelectVarType from '../select-var-type' + +describe('SelectVarType', () => { + it('should open the menu and return the selected variable type', () => { + const onChange = vi.fn() + + render() + + fireEvent.click(screen.getByText('common.operation.add')) + fireEvent.click(screen.getByText('appDebug.variableConfig.checkbox')) + + expect(onChange).toHaveBeenCalledWith('checkbox') + expect(screen.queryByText('appDebug.variableConfig.checkbox')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config-var/__tests__/var-item.spec.tsx b/web/app/components/app/configuration/config-var/__tests__/var-item.spec.tsx new file mode 100644 index 0000000000..aae00bb2b7 --- /dev/null +++ b/web/app/components/app/configuration/config-var/__tests__/var-item.spec.tsx @@ -0,0 +1,46 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import VarItem from '../var-item' + +describe('VarItem', () => { + it('should render variable metadata and allow editing', () => { + const onEdit = vi.fn() + const onRemove = vi.fn() + const { container } = render( + , + ) + + expect(screen.getByTitle('api_key · API Key')).toBeInTheDocument() + expect(screen.getByText('required')).toBeInTheDocument() + + const editButton = container.querySelector('.mr-1.flex.h-6.w-6') as HTMLElement + fireEvent.click(editButton) + + expect(onEdit).toHaveBeenCalledTimes(1) + }) + + it('should call remove when clicking the delete action', () => { + const onRemove = vi.fn() + render( + , + ) + + fireEvent.click(screen.getByTestId('var-item-delete-btn')) + + expect(onRemove).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/app/configuration/config-var/config-modal/__tests__/config.spec.ts b/web/app/components/app/configuration/config-var/config-modal/__tests__/config.spec.ts new file mode 100644 index 0000000000..efa2f793ae --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-modal/__tests__/config.spec.ts @@ -0,0 +1,23 @@ +import { jsonConfigPlaceHolder } from '../config' + +describe('config modal placeholder config', () => { + it('should contain a valid object schema example', () => { + const parsed = JSON.parse(jsonConfigPlaceHolder) as { + type: string + properties: { + foo: { type: string } + bar: { + type: string + properties: { + sub: { type: string } + } + } + } + } + + expect(parsed.type).toBe('object') + expect(parsed.properties.foo.type).toBe('string') + expect(parsed.properties.bar.type).toBe('object') + expect(parsed.properties.bar.properties.sub.type).toBe('number') + }) +}) diff --git a/web/app/components/app/configuration/config-var/config-modal/__tests__/field.spec.tsx b/web/app/components/app/configuration/config-var/config-modal/__tests__/field.spec.tsx new file mode 100644 index 0000000000..454e5dd444 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-modal/__tests__/field.spec.tsx @@ -0,0 +1,25 @@ +import { render, screen } from '@testing-library/react' +import Field from '../field' + +describe('ConfigModal Field', () => { + it('should render the title and children', () => { + render( + + + , + ) + + expect(screen.getByText('Field title')).toBeInTheDocument() + expect(screen.getByLabelText('field-input')).toBeInTheDocument() + }) + + it('should render the optional hint when requested', () => { + render( + + + , + ) + + expect(screen.getByText(/\(appDebug\.variableConfig\.optional\)/)).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config-vision/__tests__/param-config-content.spec.tsx b/web/app/components/app/configuration/config-vision/__tests__/param-config-content.spec.tsx new file mode 100644 index 0000000000..2cb919b6db --- /dev/null +++ b/web/app/components/app/configuration/config-vision/__tests__/param-config-content.spec.tsx @@ -0,0 +1,74 @@ +import type { FeatureStoreState } from '@/app/components/base/features/store' +import type { FileUpload } from '@/app/components/base/features/types' +import { fireEvent, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { Resolution, TransferMethod } from '@/types/app' +import ParamConfigContent from '../param-config-content' + +const mockUseFeatures = vi.fn() +const mockUseFeaturesStore = vi.fn() +const mockSetFeatures = vi.fn() + +vi.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: (selector: (state: FeatureStoreState) => unknown) => mockUseFeatures(selector), + useFeaturesStore: () => mockUseFeaturesStore(), +})) + +const setupFeatureStore = (fileOverrides: Partial = {}) => { + const file: FileUpload = { + enabled: true, + allowed_file_types: [], + allowed_file_upload_methods: [TransferMethod.local_file, TransferMethod.remote_url], + number_limits: 3, + image: { + enabled: true, + detail: Resolution.low, + number_limits: 3, + transfer_methods: [TransferMethod.local_file, TransferMethod.remote_url], + }, + ...fileOverrides, + } + const featureStoreState = { + features: { file }, + setFeatures: mockSetFeatures, + showFeaturesModal: false, + setShowFeaturesModal: vi.fn(), + } as unknown as FeatureStoreState + + mockUseFeatures.mockImplementation(selector => selector(featureStoreState)) + mockUseFeaturesStore.mockReturnValue({ + getState: () => featureStoreState, + }) +} + +const getUpdatedFile = () => { + expect(mockSetFeatures).toHaveBeenCalled() + return mockSetFeatures.mock.calls.at(-1)?.[0].file as FileUpload +} + +describe('ParamConfigContent', () => { + beforeEach(() => { + vi.clearAllMocks() + setupFeatureStore() + }) + + it('should update the image resolution', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('appDebug.vision.visionSettings.high')) + + expect(getUpdatedFile().image?.detail).toBe(Resolution.high) + }) + + it('should update upload methods and upload limit', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('appDebug.vision.visionSettings.localUpload')) + expect(getUpdatedFile().allowed_file_upload_methods).toEqual([TransferMethod.local_file]) + + fireEvent.change(screen.getByRole('textbox'), { target: { value: '5' } }) + expect(getUpdatedFile().number_limits).toBe(5) + }) +}) diff --git a/web/app/components/app/configuration/config-vision/__tests__/param-config.spec.tsx b/web/app/components/app/configuration/config-vision/__tests__/param-config.spec.tsx new file mode 100644 index 0000000000..617f14629e --- /dev/null +++ b/web/app/components/app/configuration/config-vision/__tests__/param-config.spec.tsx @@ -0,0 +1,58 @@ +import type { FeatureStoreState } from '@/app/components/base/features/store' +import type { FileUpload } from '@/app/components/base/features/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { Resolution, TransferMethod } from '@/types/app' +import ParamConfig from '../param-config' + +const mockUseFeatures = vi.fn() +const mockUseFeaturesStore = vi.fn() + +vi.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: (selector: (state: FeatureStoreState) => unknown) => mockUseFeatures(selector), + useFeaturesStore: () => mockUseFeaturesStore(), +})) + +const setupFeatureStore = (fileOverrides: Partial = {}) => { + const file: FileUpload = { + enabled: true, + allowed_file_types: [], + allowed_file_upload_methods: [TransferMethod.local_file, TransferMethod.remote_url], + number_limits: 3, + image: { + enabled: true, + detail: Resolution.low, + number_limits: 3, + transfer_methods: [TransferMethod.local_file, TransferMethod.remote_url], + }, + ...fileOverrides, + } + const featureStoreState = { + features: { file }, + setFeatures: vi.fn(), + showFeaturesModal: false, + setShowFeaturesModal: vi.fn(), + } as unknown as FeatureStoreState + mockUseFeatures.mockImplementation(selector => selector(featureStoreState)) + mockUseFeaturesStore.mockReturnValue({ + getState: () => featureStoreState, + }) +} + +describe('ParamConfig', () => { + beforeEach(() => { + vi.clearAllMocks() + setupFeatureStore() + }) + + it('should toggle the settings panel when clicking the trigger', async () => { + const user = userEvent.setup() + render() + + expect(screen.queryByText('appDebug.vision.visionSettings.title')).not.toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'appDebug.voice.settings' })) + + expect(await screen.findByText('appDebug.vision.visionSettings.title')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config/automatic/__tests__/prompt-toast.spec.tsx b/web/app/components/app/configuration/config/automatic/__tests__/prompt-toast.spec.tsx new file mode 100644 index 0000000000..bc380d35d0 --- /dev/null +++ b/web/app/components/app/configuration/config/automatic/__tests__/prompt-toast.spec.tsx @@ -0,0 +1,22 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import PromptToast from '../prompt-toast' + +describe('PromptToast', () => { + it('should render the note title and markdown message', () => { + render() + + expect(screen.getByText('appDebug.generate.optimizationNote')).toBeInTheDocument() + expect(screen.getByTestId('markdown-body')).toBeInTheDocument() + }) + + it('should collapse and expand the markdown content', () => { + const { container } = render() + + const toggle = container.querySelector('.cursor-pointer') as HTMLElement + fireEvent.click(toggle) + expect(screen.queryByTestId('markdown-body')).not.toBeInTheDocument() + + fireEvent.click(toggle) + expect(screen.getByTestId('markdown-body')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config/automatic/__tests__/res-placeholder.spec.tsx b/web/app/components/app/configuration/config/automatic/__tests__/res-placeholder.spec.tsx new file mode 100644 index 0000000000..cbdbda8480 --- /dev/null +++ b/web/app/components/app/configuration/config/automatic/__tests__/res-placeholder.spec.tsx @@ -0,0 +1,10 @@ +import { render, screen } from '@testing-library/react' +import ResPlaceholder from '../res-placeholder' + +describe('ResPlaceholder', () => { + it('should render the placeholder copy', () => { + render() + + expect(screen.getByText('appDebug.generate.newNoDataLine1')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config/automatic/__tests__/use-gen-data.spec.ts b/web/app/components/app/configuration/config/automatic/__tests__/use-gen-data.spec.ts new file mode 100644 index 0000000000..374a75cd7b --- /dev/null +++ b/web/app/components/app/configuration/config/automatic/__tests__/use-gen-data.spec.ts @@ -0,0 +1,39 @@ +import type { GenRes } from '@/service/debug' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it } from 'vitest' +import useGenData from '../use-gen-data' + +describe('useGenData', () => { + beforeEach(() => { + sessionStorage.clear() + }) + + it('should start with an empty version list', () => { + const { result } = renderHook(() => useGenData({ storageKey: 'prompt' })) + + expect(result.current.versions).toEqual([]) + expect(result.current.currentVersionIndex).toBe(0) + expect(result.current.current).toBeUndefined() + }) + + it('should append versions and keep the latest one selected', () => { + const versionOne = { modified: 'first version' } as GenRes + const versionTwo = { modified: 'second version' } as GenRes + const { result } = renderHook(() => useGenData({ storageKey: 'prompt' })) + + act(() => { + result.current.addVersion(versionOne) + }) + + expect(result.current.versions).toEqual([versionOne]) + expect(result.current.current).toEqual(versionOne) + + act(() => { + result.current.addVersion(versionTwo) + }) + + expect(result.current.versions).toEqual([versionOne, versionTwo]) + expect(result.current.currentVersionIndex).toBe(1) + expect(result.current.current).toEqual(versionTwo) + }) +}) diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/context-provider.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/context-provider.spec.tsx new file mode 100644 index 0000000000..5608f4c5a2 --- /dev/null +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/context-provider.spec.tsx @@ -0,0 +1,38 @@ +import { render, screen } from '@testing-library/react' +import { useDebugWithMultipleModelContext } from '../context' +import { DebugWithMultipleModelContextProvider } from '../context-provider' + +const ContextConsumer = () => { + const value = useDebugWithMultipleModelContext() + return ( +
+
{value.multipleModelConfigs.length}
+ + +
{String(value.checkCanSend?.())}
+
+ ) +} + +describe('DebugWithMultipleModelContextProvider', () => { + it('should expose the provided context value to descendants', () => { + const onMultipleModelConfigsChange = vi.fn() + const onDebugWithMultipleModelChange = vi.fn() + const checkCanSend = vi.fn(() => true) + const multipleModelConfigs = [{ model: 'gpt-4o' }] as unknown as [] + + render( + + + , + ) + + expect(screen.getByText('1')).toBeInTheDocument() + expect(screen.getByText('true')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/overview/__tests__/app-card-sections.spec.tsx b/web/app/components/app/overview/__tests__/app-card-sections.spec.tsx new file mode 100644 index 0000000000..9a818e0fd7 --- /dev/null +++ b/web/app/components/app/overview/__tests__/app-card-sections.spec.tsx @@ -0,0 +1,103 @@ +import type { AppDetailResponse } from '@/models/app' +import { fireEvent, render, screen } from '@testing-library/react' +import { AccessMode } from '@/models/access-control' +import { AppModeEnum } from '@/types/app' +import { AppCardAccessControlSection, AppCardOperations, createAppCardOperations } from '../app-card-sections' + +describe('app-card-sections', () => { + const t = (key: string) => key + + it('should build operations with the expected disabled state', () => { + const onLaunch = vi.fn() + const operations = createAppCardOperations({ + operationKeys: ['launch', 'settings'], + t: t as never, + runningStatus: false, + triggerModeDisabled: false, + onLaunch, + onEmbedded: vi.fn(), + onCustomize: vi.fn(), + onSettings: vi.fn(), + onDevelop: vi.fn(), + }) + + expect(operations[0]).toMatchObject({ + key: 'launch', + disabled: true, + label: 'overview.appInfo.launch', + }) + expect(operations[1]).toMatchObject({ + key: 'settings', + disabled: false, + label: 'overview.appInfo.settings.entry', + }) + }) + + it('should render the access-control section and call onClick', () => { + const onClick = vi.fn() + render( + , + ) + + fireEvent.click(screen.getByText('publishApp.notSet')) + + expect(screen.getByText('accessControlDialog.accessItems.specific')).toBeInTheDocument() + expect(onClick).toHaveBeenCalledTimes(1) + }) + + it('should render operation buttons and execute enabled actions', () => { + const onLaunch = vi.fn() + const operations = createAppCardOperations({ + operationKeys: ['launch', 'embedded'], + t: t as never, + runningStatus: true, + triggerModeDisabled: false, + onLaunch, + onEmbedded: vi.fn(), + onCustomize: vi.fn(), + onSettings: vi.fn(), + onDevelop: vi.fn(), + }) + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: /overview\.appInfo\.launch/i })) + + expect(onLaunch).toHaveBeenCalledTimes(1) + expect(screen.getByRole('button', { name: /overview\.appInfo\.embedded\.entry/i })).toBeInTheDocument() + }) + + it('should keep customize available for web app cards that are not completion or workflow apps', () => { + const operations = createAppCardOperations({ + operationKeys: ['customize'], + t: t as never, + runningStatus: true, + triggerModeDisabled: false, + onLaunch: vi.fn(), + onEmbedded: vi.fn(), + onCustomize: vi.fn(), + onSettings: vi.fn(), + onDevelop: vi.fn(), + }) + + render( + , + ) + + expect(screen.getByText('overview.appInfo.customize.entry')).toBeInTheDocument() + expect(AppModeEnum.CHAT).toBe('chat') + }) +}) diff --git a/web/app/components/app/overview/__tests__/app-card-utils.spec.ts b/web/app/components/app/overview/__tests__/app-card-utils.spec.ts new file mode 100644 index 0000000000..fbfcdaf955 --- /dev/null +++ b/web/app/components/app/overview/__tests__/app-card-utils.spec.ts @@ -0,0 +1,107 @@ +import type { AppDetailResponse } from '@/models/app' +import { BlockEnum } from '@/app/components/workflow/types' +import { AccessMode } from '@/models/access-control' +import { AppModeEnum } from '@/types/app' +import { basePath } from '@/utils/var' +import { getAppCardDisplayState, getAppCardOperationKeys, hasWorkflowStartNode, isAppAccessConfigured } from '../app-card-utils' + +describe('app-card-utils', () => { + const baseAppInfo = { + id: 'app-1', + mode: AppModeEnum.CHAT, + enable_site: true, + enable_api: false, + access_mode: AccessMode.PUBLIC, + api_base_url: 'https://api.example.com', + site: { + app_base_url: 'https://example.com', + access_token: 'token-1', + }, + } as AppDetailResponse + + it('should detect whether the workflow includes a start node', () => { + expect(hasWorkflowStartNode({ + graph: { + nodes: [{ data: { type: BlockEnum.Start } }], + }, + })).toBe(true) + + expect(hasWorkflowStartNode({ + graph: { + nodes: [{ data: { type: BlockEnum.Answer } }], + }, + })).toBe(false) + }) + + it('should build the display state for a published web app', () => { + const state = getAppCardDisplayState({ + appInfo: baseAppInfo, + cardType: 'webapp', + currentWorkflow: null, + isCurrentWorkspaceEditor: true, + isCurrentWorkspaceManager: true, + }) + + expect(state.isApp).toBe(true) + expect(state.appMode).toBe(AppModeEnum.CHAT) + expect(state.runningStatus).toBe(true) + expect(state.accessibleUrl).toBe(`https://example.com${basePath}/chat/token-1`) + }) + + it('should disable workflow cards without a graph or start node', () => { + const unpublishedState = getAppCardDisplayState({ + appInfo: { ...baseAppInfo, mode: AppModeEnum.WORKFLOW }, + cardType: 'webapp', + currentWorkflow: null, + isCurrentWorkspaceEditor: true, + isCurrentWorkspaceManager: true, + }) + expect(unpublishedState.appUnpublished).toBe(true) + expect(unpublishedState.toggleDisabled).toBe(true) + + const missingStartState = getAppCardDisplayState({ + appInfo: { ...baseAppInfo, mode: AppModeEnum.WORKFLOW }, + cardType: 'webapp', + currentWorkflow: { + graph: { + nodes: [{ data: { type: BlockEnum.Answer } }], + }, + }, + isCurrentWorkspaceEditor: true, + isCurrentWorkspaceManager: true, + }) + expect(missingStartState.missingStartNode).toBe(true) + expect(missingStartState.runningStatus).toBe(false) + }) + + it('should require specific access subjects only for the specific access mode', () => { + expect(isAppAccessConfigured( + { ...baseAppInfo, access_mode: AccessMode.PUBLIC }, + { groups: [], members: [] }, + )).toBe(true) + + expect(isAppAccessConfigured( + { ...baseAppInfo, access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS }, + { groups: [], members: [] }, + )).toBe(false) + + expect(isAppAccessConfigured( + { ...baseAppInfo, access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS }, + { groups: [{ id: 'group-1' }], members: [] }, + )).toBe(true) + }) + + it('should derive operation keys for api and webapp cards', () => { + expect(getAppCardOperationKeys({ + cardType: 'api', + appMode: AppModeEnum.COMPLETION, + isCurrentWorkspaceEditor: true, + })).toEqual(['develop']) + + expect(getAppCardOperationKeys({ + cardType: 'webapp', + appMode: AppModeEnum.CHAT, + isCurrentWorkspaceEditor: false, + })).toEqual(['launch', 'embedded', 'customize']) + }) +}) diff --git a/web/app/components/base/copy-feedback/__tests__/index.spec.tsx b/web/app/components/base/copy-feedback/__tests__/index.spec.tsx index 322a9970af..8cc22693b6 100644 --- a/web/app/components/base/copy-feedback/__tests__/index.spec.tsx +++ b/web/app/components/base/copy-feedback/__tests__/index.spec.tsx @@ -5,7 +5,7 @@ const mockCopy = vi.fn() const mockReset = vi.fn() let mockCopied = false -vi.mock('foxact/use-clipboard', () => ({ +vi.mock('@/hooks/use-clipboard', () => ({ useClipboard: () => ({ copy: mockCopy, reset: mockReset, diff --git a/web/app/components/base/copy-feedback/index.tsx b/web/app/components/base/copy-feedback/index.tsx index 80b35eb3a8..5210066670 100644 --- a/web/app/components/base/copy-feedback/index.tsx +++ b/web/app/components/base/copy-feedback/index.tsx @@ -3,11 +3,11 @@ import { RiClipboardFill, RiClipboardLine, } from '@remixicon/react' -import { useClipboard } from 'foxact/use-clipboard' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import Tooltip from '@/app/components/base/tooltip' +import { useClipboard } from '@/hooks/use-clipboard' import copyStyle from './style.module.css' type Props = { diff --git a/web/app/components/base/copy-icon/__tests__/index.spec.tsx b/web/app/components/base/copy-icon/__tests__/index.spec.tsx index 3db76ef606..1ce9e6dbf5 100644 --- a/web/app/components/base/copy-icon/__tests__/index.spec.tsx +++ b/web/app/components/base/copy-icon/__tests__/index.spec.tsx @@ -5,7 +5,7 @@ const copy = vi.fn() const reset = vi.fn() let copied = false -vi.mock('foxact/use-clipboard', () => ({ +vi.mock('@/hooks/use-clipboard', () => ({ useClipboard: () => ({ copy, reset, diff --git a/web/app/components/base/copy-icon/index.tsx b/web/app/components/base/copy-icon/index.tsx index 78c0fcb8c3..15332592d0 100644 --- a/web/app/components/base/copy-icon/index.tsx +++ b/web/app/components/base/copy-icon/index.tsx @@ -1,7 +1,7 @@ 'use client' -import { useClipboard } from 'foxact/use-clipboard' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' +import { useClipboard } from '@/hooks/use-clipboard' import Tooltip from '../tooltip' type Props = { diff --git a/web/app/components/base/input-with-copy/__tests__/index.spec.tsx b/web/app/components/base/input-with-copy/__tests__/index.spec.tsx index 201c419444..33ebec5cbc 100644 --- a/web/app/components/base/input-with-copy/__tests__/index.spec.tsx +++ b/web/app/components/base/input-with-copy/__tests__/index.spec.tsx @@ -6,7 +6,7 @@ const mockCopy = vi.fn() let mockCopied = false const mockReset = vi.fn() -vi.mock('foxact/use-clipboard', () => ({ +vi.mock('@/hooks/use-clipboard', () => ({ useClipboard: () => ({ copy: mockCopy, copied: mockCopied, diff --git a/web/app/components/base/input-with-copy/index.tsx b/web/app/components/base/input-with-copy/index.tsx index e85a7bd6f4..33db47baaa 100644 --- a/web/app/components/base/input-with-copy/index.tsx +++ b/web/app/components/base/input-with-copy/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { InputProps } from '../input' -import { useClipboard } from 'foxact/use-clipboard' import * as React from 'react' import { useTranslation } from 'react-i18next' +import { useClipboard } from '@/hooks/use-clipboard' import { cn } from '@/utils/classnames' import ActionButton from '../action-button' import Tooltip from '../tooltip' diff --git a/web/app/components/base/markdown/streamdown-wrapper.tsx b/web/app/components/base/markdown/streamdown-wrapper.tsx index 46db301adb..e20898135b 100644 --- a/web/app/components/base/markdown/streamdown-wrapper.tsx +++ b/web/app/components/base/markdown/streamdown-wrapper.tsx @@ -16,7 +16,7 @@ import { ThinkBlock, VideoBlock, } from '@/app/components/base/markdown-blocks' -import { ENABLE_SINGLE_DOLLAR_LATEX } from '@/config' +import { ALLOW_INLINE_STYLES, ENABLE_SINGLE_DOLLAR_LATEX } from '@/config' import dynamic from '@/next/dynamic' import { customUrlTransform } from './markdown-utils' import 'katex/dist/katex.min.css' @@ -118,6 +118,11 @@ function buildRehypePlugins(extraPlugins?: PluggableList): PluggableList { // component validates names with `isSafeName()`, so remove it. const clobber = (defaultSanitizeSchema.clobber ?? []).filter(k => k !== 'name') + if (ALLOW_INLINE_STYLES) { + const globalAttrs = mergedAttributes['*'] ?? [] + mergedAttributes['*'] = [...globalAttrs, 'style'] + } + const customSchema: SanitizeSchema = { ...defaultSanitizeSchema, tagNames: [...tagNamesSet], diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/index.spec.tsx index b5f38cdd1b..62b867d155 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/index.spec.tsx @@ -120,7 +120,10 @@ describe('HITLInputBlock', () => { }) await waitFor(() => { - expect(onWorkflowMapUpdate).toHaveBeenCalledWith(workflowNodesMap) + expect(onWorkflowMapUpdate).toHaveBeenCalledWith({ + workflowNodesMap, + availableVariables: [], + }) }) }) }) diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/variable-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/variable-block.spec.tsx index c848d08c5c..db3e474b60 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/variable-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/variable-block.spec.tsx @@ -148,7 +148,10 @@ describe('HITLInputVariableBlockComponent', () => { editor!.update(() => { $getRoot().selectEnd() }) - handled = editor!.dispatchCommand(UPDATE_WORKFLOW_NODES_MAP, createWorkflowNodesMap()) + handled = editor!.dispatchCommand(UPDATE_WORKFLOW_NODES_MAP, { + workflowNodesMap: createWorkflowNodesMap(), + availableVariables: [], + }) }) expect(handled).toBe(true) diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/hitl-input-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/hitl-input-block-replacement-block.tsx index cd1515c57d..0da99b9155 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/hitl-input-block-replacement-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/hitl-input-block-replacement-block.tsx @@ -22,7 +22,7 @@ const HITLInputReplacementBlock = ({ onFormInputsChange, onFormInputItemRename, onFormInputItemRemove, - workflowNodesMap, + workflowNodesMap = {}, getVarType, variables, readonly, diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/index.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/index.tsx index 2c10fdbd5a..1b2af39ebe 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/index.tsx @@ -14,6 +14,7 @@ import { useEffect, } from 'react' import { CustomTextNode } from '../custom-text/node' +import { UPDATE_WORKFLOW_NODES_MAP as WORKFLOW_UPDATE_WORKFLOW_NODES_MAP } from '../workflow-variable-block' import { $createHITLInputNode, HITLInputNode, @@ -21,11 +22,13 @@ import { export const INSERT_HITL_INPUT_BLOCK_COMMAND = createCommand('INSERT_HITL_INPUT_BLOCK_COMMAND') export const DELETE_HITL_INPUT_BLOCK_COMMAND = createCommand('DELETE_HITL_INPUT_BLOCK_COMMAND') -export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP') +export const UPDATE_WORKFLOW_NODES_MAP = WORKFLOW_UPDATE_WORKFLOW_NODES_MAP + const HITLInputBlock = memo(({ onInsert, onDelete, - workflowNodesMap, + workflowNodesMap = {}, + variables: workflowAvailableVariables, getVarType, readonly, }: HITLInputBlockType) => { @@ -33,9 +36,12 @@ const HITLInputBlock = memo(({ useEffect(() => { editor.update(() => { - editor.dispatchCommand(UPDATE_WORKFLOW_NODES_MAP, workflowNodesMap) + editor.dispatchCommand(UPDATE_WORKFLOW_NODES_MAP, { + workflowNodesMap: workflowNodesMap || {}, + availableVariables: workflowAvailableVariables || [], + }) }) - }, [editor, workflowNodesMap]) + }, [editor, workflowNodesMap, workflowAvailableVariables]) useEffect(() => { if (!editor.hasNodes([HITLInputNode])) diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/variable-block.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/variable-block.tsx index b1374b994f..a466d64eff 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/variable-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/variable-block.tsx @@ -1,3 +1,4 @@ +import type { UpdateWorkflowNodesMapPayload } from '../workflow-variable-block' import type { WorkflowNodesMap } from '../workflow-variable-block/node' import type { ValueSelector, Var } from '@/app/components/workflow/types' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' @@ -98,9 +99,8 @@ const HITLInputVariableBlockComponent = ({ return mergeRegister( editor.registerCommand( UPDATE_WORKFLOW_NODES_MAP, - (workflowNodesMap: WorkflowNodesMap) => { - setLocalWorkflowNodesMap(workflowNodesMap) - + (payload: UpdateWorkflowNodesMapPayload) => { + setLocalWorkflowNodesMap(payload.workflowNodesMap) return true }, COMMAND_PRIORITY_EDITOR, diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/component.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/component.spec.tsx index ff064f2a99..a6cb70ddb6 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/component.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/component.spec.tsx @@ -1,4 +1,5 @@ import type { LexicalEditor } from 'lexical' +import type { UpdateWorkflowNodesMapPayload } from '../index' import type { ValueSelector, Var } from '@/app/components/workflow/types' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { mergeRegister } from '@lexical/utils' @@ -216,7 +217,7 @@ describe('WorkflowVariableBlockComponent', () => { }) }) - it('should mark env variable invalid when not found in environmentVariables', () => { + it('should treat env variable as valid regardless of environmentVariables contents', () => { const environmentVariables: Var[] = [{ variable: 'env.valid_key', type: VarType.string }] render( @@ -229,7 +230,7 @@ describe('WorkflowVariableBlockComponent', () => { ) expect(mockVarLabel).toHaveBeenCalledWith(expect.objectContaining({ - errorMsg: expect.any(String), + errorMsg: undefined, })) }) @@ -281,7 +282,7 @@ describe('WorkflowVariableBlockComponent', () => { })) }) - it('should evaluate env fallback selector tokens when classifier is forced', () => { + it('should mark forced env branch invalid when selector prefix is missing', () => { mockForcedVariableKind.value = 'env' const environmentVariables: Var[] = [{ variable: '.', type: VarType.string }] @@ -295,7 +296,7 @@ describe('WorkflowVariableBlockComponent', () => { ) expect(mockVarLabel).toHaveBeenCalledWith(expect.objectContaining({ - errorMsg: undefined, + errorMsg: expect.any(String), })) }) @@ -330,7 +331,7 @@ describe('WorkflowVariableBlockComponent', () => { })) }) - it('should mark conversation variable invalid when not found in conversationVariables', () => { + it('should treat conversation variable as valid regardless of conversationVariables contents', () => { const conversationVariables: Var[] = [{ variable: 'conversation.other', type: VarType.string }] render( @@ -343,7 +344,7 @@ describe('WorkflowVariableBlockComponent', () => { ) expect(mockVarLabel).toHaveBeenCalledWith(expect.objectContaining({ - errorMsg: expect.any(String), + errorMsg: undefined, })) }) @@ -364,7 +365,7 @@ describe('WorkflowVariableBlockComponent', () => { })) }) - it('should evaluate conversation fallback selector tokens when classifier is forced', () => { + it('should mark forced conversation branch invalid when selector prefix is missing', () => { mockForcedVariableKind.value = 'conversation' const conversationVariables: Var[] = [{ variable: '.', type: VarType.string }] @@ -378,7 +379,7 @@ describe('WorkflowVariableBlockComponent', () => { ) expect(mockVarLabel).toHaveBeenCalledWith(expect.objectContaining({ - errorMsg: undefined, + errorMsg: expect.any(String), })) }) @@ -427,7 +428,7 @@ describe('WorkflowVariableBlockComponent', () => { })) }) - it('should mark rag variable invalid when not found in ragVariables', () => { + it('should treat rag variable as valid regardless of ragVariables contents', () => { const ragVariables: Var[] = [{ variable: 'rag.shared.other', type: VarType.string }] render( @@ -440,7 +441,7 @@ describe('WorkflowVariableBlockComponent', () => { ) expect(mockVarLabel).toHaveBeenCalledWith(expect.objectContaining({ - errorMsg: expect.any(String), + errorMsg: undefined, })) }) @@ -461,7 +462,7 @@ describe('WorkflowVariableBlockComponent', () => { })) }) - it('should evaluate rag fallback selector tokens when classifier is forced', () => { + it('should mark forced rag branch invalid when selector prefix is missing', () => { mockForcedVariableKind.value = 'rag' const ragVariables: Var[] = [{ variable: '..', type: VarType.string }] @@ -475,7 +476,7 @@ describe('WorkflowVariableBlockComponent', () => { ) expect(mockVarLabel).toHaveBeenCalledWith(expect.objectContaining({ - errorMsg: undefined, + errorMsg: expect.any(String), })) }) @@ -488,20 +489,81 @@ describe('WorkflowVariableBlockComponent', () => { />, ) - const updateHandler = mockRegisterCommand.mock.calls[0][1] as (map: Record) => boolean + const updateHandler = mockRegisterCommand.mock.calls[0][1] as (payload: UpdateWorkflowNodesMapPayload) => boolean let result = false act(() => { result = updateHandler({ - 'node-1': { - title: 'Updated', - type: BlockEnum.LLM, - width: 100, - height: 50, - position: { x: 0, y: 0 }, + workflowNodesMap: { + 'node-1': { + title: 'Updated', + type: BlockEnum.LLM, + width: 100, + height: 50, + position: { x: 0, y: 0 }, + }, }, + availableVariables: [], }) }) expect(result).toBe(true) }) + + it('should mark non-special variable invalid when source key is missing in availableVariables', () => { + render( + , + ) + + expect(mockVarLabel).toHaveBeenCalledWith(expect.objectContaining({ + errorMsg: expect.any(String), + })) + }) + + it('should keep non-special variable valid when source key exists in availableVariables', () => { + render( + , + ) + + expect(mockVarLabel).toHaveBeenCalledWith(expect.objectContaining({ + errorMsg: undefined, + })) + }) }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx index 1591dc44f9..00b5b66660 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx @@ -105,7 +105,10 @@ describe('WorkflowVariableBlock', () => { ) expect(mockUpdate).toHaveBeenCalled() - expect(mockDispatchCommand).toHaveBeenCalledWith(UPDATE_WORKFLOW_NODES_MAP, workflowNodesMap) + expect(mockDispatchCommand).toHaveBeenCalledWith(UPDATE_WORKFLOW_NODES_MAP, { + workflowNodesMap, + availableVariables: [], + }) }) it('should throw when WorkflowVariableBlockNode is not registered', () => { @@ -137,6 +140,7 @@ describe('WorkflowVariableBlock', () => { ['node-1', 'answer'], workflowNodesMap, getVarType, + [], ) expect($insertNodes).toHaveBeenCalledWith([{ id: 'workflow-node' }]) expect(onInsert).toHaveBeenCalledTimes(1) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/node.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/node.spec.tsx index 8d7a1cc33d..4154cd2fd9 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/node.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/node.spec.tsx @@ -1,5 +1,5 @@ import type { Klass, LexicalEditor, LexicalNode } from 'lexical' -import type { Var } from '@/app/components/workflow/types' +import type { NodeOutPutVar } from '@/app/components/workflow/types' import { createEditor } from 'lexical' import { Type } from '@/app/components/workflow/nodes/llm/types' import { BlockEnum, VarType } from '@/app/components/workflow/types' @@ -57,45 +57,43 @@ describe('WorkflowVariableBlockNode', () => { it('should decorate with component props from node state', () => { runInEditor(() => { const getVarType = vi.fn(() => Type.number) - const environmentVariables: Var[] = [{ variable: 'env.key', type: VarType.string }] - const conversationVariables: Var[] = [{ variable: 'conversation.topic', type: VarType.string }] - const ragVariables: Var[] = [{ variable: 'rag.shared.answer', type: VarType.string }] + const availableVariables: NodeOutPutVar[] = [{ + nodeId: 'node-1', + title: 'Node A', + vars: [{ variable: 'answer', type: VarType.string }], + }] const node = new WorkflowVariableBlockNode( ['node-1', 'answer'], { 'node-1': { title: 'A', type: BlockEnum.LLM } }, getVarType, 'decorator-key', - environmentVariables, - conversationVariables, - ragVariables, + availableVariables, ) const decorated = node.decorate() expect(decorated.props.nodeKey).toBe('decorator-key') expect(decorated.props.variables).toEqual(['node-1', 'answer']) expect(decorated.props.workflowNodesMap).toEqual({ 'node-1': { title: 'A', type: BlockEnum.LLM } }) - expect(decorated.props.environmentVariables).toEqual(environmentVariables) - expect(decorated.props.conversationVariables).toEqual(conversationVariables) - expect(decorated.props.ragVariables).toEqual(ragVariables) + expect(decorated.props.availableVariables).toEqual(availableVariables) }) }) - it('should export and import json with full payload', () => { + it('should export and import json with available variables payload', () => { runInEditor(() => { const getVarType = vi.fn(() => Type.string) - const environmentVariables: Var[] = [{ variable: 'env.key', type: VarType.string }] - const conversationVariables: Var[] = [{ variable: 'conversation.topic', type: VarType.string }] - const ragVariables: Var[] = [{ variable: 'rag.shared.answer', type: VarType.string }] + const availableVariables: NodeOutPutVar[] = [{ + nodeId: 'node-1', + title: 'Node A', + vars: [{ variable: 'answer', type: VarType.string }], + }] const node = new WorkflowVariableBlockNode( ['node-1', 'answer'], { 'node-1': { title: 'A', type: BlockEnum.LLM } }, getVarType, undefined, - environmentVariables, - conversationVariables, - ragVariables, + availableVariables, ) expect(node.exportJSON()).toEqual({ @@ -104,9 +102,7 @@ describe('WorkflowVariableBlockNode', () => { variables: ['node-1', 'answer'], workflowNodesMap: { 'node-1': { title: 'A', type: BlockEnum.LLM } }, getVarType, - environmentVariables, - conversationVariables, - ragVariables, + availableVariables, }) const imported = WorkflowVariableBlockNode.importJSON({ @@ -115,48 +111,51 @@ describe('WorkflowVariableBlockNode', () => { variables: ['node-2', 'result'], workflowNodesMap: { 'node-2': { title: 'B', type: BlockEnum.Tool } }, getVarType, - environmentVariables, - conversationVariables, - ragVariables, + availableVariables, }) expect(imported).toBeInstanceOf(WorkflowVariableBlockNode) expect(imported.getVariables()).toEqual(['node-2', 'result']) expect(imported.getWorkflowNodesMap()).toEqual({ 'node-2': { title: 'B', type: BlockEnum.Tool } }) + expect(imported.getAvailableVariables()).toEqual(availableVariables) }) }) it('should return getters and text content in expected format', () => { runInEditor(() => { const getVarType = vi.fn(() => Type.string) - const environmentVariables: Var[] = [{ variable: 'env.key', type: VarType.string }] - const conversationVariables: Var[] = [{ variable: 'conversation.topic', type: VarType.string }] - const ragVariables: Var[] = [{ variable: 'rag.shared.answer', type: VarType.string }] + const availableVariables: NodeOutPutVar[] = [{ + nodeId: 'node-1', + title: 'Node A', + vars: [{ variable: 'answer', type: VarType.string }], + }] const node = new WorkflowVariableBlockNode( ['node-1', 'answer'], { 'node-1': { title: 'A', type: BlockEnum.LLM } }, getVarType, undefined, - environmentVariables, - conversationVariables, - ragVariables, + availableVariables, ) expect(node.getVariables()).toEqual(['node-1', 'answer']) expect(node.getWorkflowNodesMap()).toEqual({ 'node-1': { title: 'A', type: BlockEnum.LLM } }) expect(node.getVarType()).toBe(getVarType) - expect(node.getEnvironmentVariables()).toEqual(environmentVariables) - expect(node.getConversationVariables()).toEqual(conversationVariables) - expect(node.getRagVariables()).toEqual(ragVariables) + expect(node.getAvailableVariables()).toEqual(availableVariables) expect(node.getTextContent()).toBe('{{#node-1.answer#}}') }) }) it('should create node helper and type guard checks', () => { runInEditor(() => { - const node = $createWorkflowVariableBlockNode(['node-1', 'answer'], {}, undefined) + const availableVariables: NodeOutPutVar[] = [{ + nodeId: 'node-1', + title: 'Node A', + vars: [{ variable: 'answer', type: VarType.string }], + }] + const node = $createWorkflowVariableBlockNode(['node-1', 'answer'], {}, undefined, availableVariables) expect(node).toBeInstanceOf(WorkflowVariableBlockNode) + expect(node.getAvailableVariables()).toEqual(availableVariables) expect($isWorkflowVariableBlockNode(node)).toBe(true) expect($isWorkflowVariableBlockNode(null)).toBe(false) expect($isWorkflowVariableBlockNode(undefined)).toBe(false) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/workflow-variable-block-replacement-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/workflow-variable-block-replacement-block.spec.tsx index b9cb1faa37..9dcc37ec35 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/workflow-variable-block-replacement-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/workflow-variable-block-replacement-block.spec.tsx @@ -183,12 +183,7 @@ describe('WorkflowVariableBlockReplacementBlock', () => { ['node-1', 'output'], workflowNodesMap, getVarType, - variables[0].vars, - variables[1].vars, - [ - { variable: 'ragVarA', type: VarType.string, isRagVariable: true }, - { variable: 'rag.shared.answer', type: VarType.string, isRagVariable: true }, - ], + variables, ) expect($applyNodeReplacement).toHaveBeenCalledWith({ type: 'workflow-node' }) expect(created).toEqual({ type: 'workflow-node' }) @@ -214,8 +209,6 @@ describe('WorkflowVariableBlockReplacementBlock', () => { workflowNodesMap, undefined, [], - [], - undefined, ) }) }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx index 2b46d1a378..bf91d25834 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx @@ -1,5 +1,8 @@ +import type { + UpdateWorkflowNodesMapPayload, +} from './index' import type { WorkflowNodesMap } from './node' -import type { ValueSelector, Var } from '@/app/components/workflow/types' +import type { NodeOutPutVar, ValueSelector, Var } from '@/app/components/workflow/types' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { mergeRegister } from '@lexical/utils' import { @@ -15,7 +18,7 @@ import { import { useTranslation } from 'react-i18next' import { useReactFlow, useStoreApi } from 'reactflow' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' -import { isConversationVar, isENV, isGlobalVar, isRagVariableVar, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' +import { isRagVariableVar, isSpecialVar, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' import VarFullPathPanel from '@/app/components/workflow/nodes/_base/components/variable/var-full-path-panel' import { VariableLabelInEditor, @@ -34,6 +37,7 @@ type WorkflowVariableBlockComponentProps = { nodeKey: string variables: string[] workflowNodesMap: WorkflowNodesMap + availableVariables?: NodeOutPutVar[] environmentVariables?: Var[] conversationVariables?: Var[] ragVariables?: Var[] @@ -47,10 +51,8 @@ const WorkflowVariableBlockComponent = ({ nodeKey, variables, workflowNodesMap = {}, + availableVariables, getVarType, - environmentVariables, - conversationVariables, - ragVariables, }: WorkflowVariableBlockComponentProps) => { const { t } = useTranslation() const [editor] = useLexicalComposerContext() @@ -66,36 +68,25 @@ const WorkflowVariableBlockComponent = ({ } )() const [localWorkflowNodesMap, setLocalWorkflowNodesMap] = useState(workflowNodesMap) + const [localAvailableVariables, setLocalAvailableVariables] = useState(availableVariables || []) const node = localWorkflowNodesMap![variables[isRagVar ? 1 : 0]] const isException = isExceptionVariable(varName, node?.type) const sourceNodeId = variables[isRagVar ? 1 : 0] const isLlmModelInstalled = useLlmModelPluginInstalled(sourceNodeId, localWorkflowNodesMap) const variableValid = useMemo(() => { - let variableValid = true - const isEnv = isENV(variables) - const isChatVar = isConversationVar(variables) - const isGlobal = isGlobalVar(variables) - if (isGlobal) + if (isSpecialVar(variables[0] ?? '')) return true - if (isEnv) { - if (environmentVariables) - variableValid = environmentVariables.some(v => v.variable === `${variables?.[0] ?? ''}.${variables?.[1] ?? ''}`) - } - else if (isChatVar) { - if (conversationVariables) - variableValid = conversationVariables.some(v => v.variable === `${variables?.[0] ?? ''}.${variables?.[1] ?? ''}`) - } - else if (isRagVar) { - if (ragVariables) - variableValid = ragVariables.some(v => v.variable === `${variables?.[0] ?? ''}.${variables?.[1] ?? ''}.${variables?.[2] ?? ''}`) - } - else { - variableValid = !!node - } - return variableValid - }, [variables, node, environmentVariables, conversationVariables, isRagVar, ragVariables]) + if (!variables[1]) + return false + + const sourceNode = localAvailableVariables.find(v => v.nodeId === variables[0]) + if (!sourceNode) + return false + + return sourceNode.vars.some(v => v.variable === variables[1]) + }, [localAvailableVariables, variables]) const reactflow = useReactFlow() const store = useStoreApi() @@ -107,9 +98,9 @@ const WorkflowVariableBlockComponent = ({ return mergeRegister( editor.registerCommand( UPDATE_WORKFLOW_NODES_MAP, - (workflowNodesMap: WorkflowNodesMap) => { - setLocalWorkflowNodesMap(workflowNodesMap) - + (payload: UpdateWorkflowNodesMapPayload) => { + setLocalWorkflowNodesMap(payload.workflowNodesMap) + setLocalAvailableVariables(payload.availableVariables) return true }, COMMAND_PRIORITY_EDITOR, diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx index dfbd238dbf..ab79630f80 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx @@ -17,9 +17,14 @@ import { export const INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND') export const DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND') -export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP') +export type UpdateWorkflowNodesMapPayload = { + workflowNodesMap: NonNullable + availableVariables: NonNullable +} +export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP') const WorkflowVariableBlock = memo(({ - workflowNodesMap, + workflowNodesMap = {}, + variables: workflowAvailableVariables, onInsert, onDelete, getVarType, @@ -28,9 +33,12 @@ const WorkflowVariableBlock = memo(({ useEffect(() => { editor.update(() => { - editor.dispatchCommand(UPDATE_WORKFLOW_NODES_MAP, workflowNodesMap) + editor.dispatchCommand(UPDATE_WORKFLOW_NODES_MAP, { + workflowNodesMap: workflowNodesMap || {}, + availableVariables: workflowAvailableVariables || [], + }) }) - }, [editor, workflowNodesMap]) + }, [editor, workflowNodesMap, workflowAvailableVariables]) useEffect(() => { if (!editor.hasNodes([WorkflowVariableBlockNode])) @@ -40,7 +48,12 @@ const WorkflowVariableBlock = memo(({ editor.registerCommand( INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, (variables: string[]) => { - const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) + const workflowVariableBlockNode = $createWorkflowVariableBlockNode( + variables, + workflowNodesMap, + getVarType, + workflowAvailableVariables || [], + ) $insertNodes([workflowVariableBlockNode]) if (onInsert) @@ -61,7 +74,7 @@ const WorkflowVariableBlock = memo(({ COMMAND_PRIORITY_EDITOR, ), ) - }, [editor, onInsert, onDelete, workflowNodesMap, getVarType]) + }, [editor, onInsert, onDelete, workflowNodesMap, getVarType, workflowAvailableVariables]) return null }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx index 743937d8a6..2d13627b20 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx @@ -1,49 +1,55 @@ import type { LexicalNode, NodeKey, SerializedLexicalNode } from 'lexical' import type { GetVarType, WorkflowVariableBlockType } from '../../types' -import type { Var } from '@/app/components/workflow/types' +import type { NodeOutPutVar } from '@/app/components/workflow/types' import { DecoratorNode } from 'lexical' import WorkflowVariableBlockComponent from './component' -export type WorkflowNodesMap = WorkflowVariableBlockType['workflowNodesMap'] +export type WorkflowNodesMap = NonNullable type SerializedNode = SerializedLexicalNode & { variables: string[] workflowNodesMap: WorkflowNodesMap getVarType?: GetVarType - environmentVariables?: Var[] - conversationVariables?: Var[] - ragVariables?: Var[] + availableVariables?: NodeOutPutVar[] } export class WorkflowVariableBlockNode extends DecoratorNode { __variables: string[] __workflowNodesMap: WorkflowNodesMap __getVarType?: GetVarType - __environmentVariables?: Var[] - __conversationVariables?: Var[] - __ragVariables?: Var[] + __availableVariables?: NodeOutPutVar[] static getType(): string { return 'workflow-variable-block' } static clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode { - return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__getVarType, node.__key, node.__environmentVariables, node.__conversationVariables, node.__ragVariables) + return new WorkflowVariableBlockNode( + node.__variables, + node.__workflowNodesMap, + node.__getVarType, + node.__key, + node.__availableVariables, + ) } isInline(): boolean { return true } - constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType: any, key?: NodeKey, environmentVariables?: Var[], conversationVariables?: Var[], ragVariables?: Var[]) { + constructor( + variables: string[], + workflowNodesMap: WorkflowNodesMap, + getVarType: any, + key?: NodeKey, + availableVariables?: NodeOutPutVar[], + ) { super(key) this.__variables = variables this.__workflowNodesMap = workflowNodesMap this.__getVarType = getVarType - this.__environmentVariables = environmentVariables - this.__conversationVariables = conversationVariables - this.__ragVariables = ragVariables + this.__availableVariables = availableVariables } createDOM(): HTMLElement { @@ -63,30 +69,34 @@ export class WorkflowVariableBlockNode extends DecoratorNode variables={this.__variables} workflowNodesMap={this.__workflowNodesMap} getVarType={this.__getVarType!} - environmentVariables={this.__environmentVariables} - conversationVariables={this.__conversationVariables} - ragVariables={this.__ragVariables} + availableVariables={this.__availableVariables} /> ) } static importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode { - const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap, serializedNode.getVarType, serializedNode.environmentVariables, serializedNode.conversationVariables, serializedNode.ragVariables) + const node = $createWorkflowVariableBlockNode( + serializedNode.variables, + serializedNode.workflowNodesMap, + serializedNode.getVarType, + serializedNode.availableVariables, + ) return node } exportJSON(): SerializedNode { - return { + const json: SerializedNode = { type: 'workflow-variable-block', version: 1, variables: this.getVariables(), workflowNodesMap: this.getWorkflowNodesMap(), getVarType: this.getVarType(), - environmentVariables: this.getEnvironmentVariables(), - conversationVariables: this.getConversationVariables(), - ragVariables: this.getRagVariables(), } + if (this.getAvailableVariables()) + json.availableVariables = this.getAvailableVariables() + + return json } getVariables(): string[] { @@ -104,27 +114,28 @@ export class WorkflowVariableBlockNode extends DecoratorNode return self.__getVarType } - getEnvironmentVariables(): any { + getAvailableVariables(): NodeOutPutVar[] | undefined { const self = this.getLatest() - return self.__environmentVariables - } - - getConversationVariables(): any { - const self = this.getLatest() - return self.__conversationVariables - } - - getRagVariables(): any { - const self = this.getLatest() - return self.__ragVariables + return self.__availableVariables } getTextContent(): string { return `{{#${this.getVariables().join('.')}#}}` } } -export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType?: GetVarType, environmentVariables?: Var[], conversationVariables?: Var[], ragVariables?: Var[]): WorkflowVariableBlockNode { - return new WorkflowVariableBlockNode(variables, workflowNodesMap, getVarType, undefined, environmentVariables, conversationVariables, ragVariables) +export function $createWorkflowVariableBlockNode( + variables: string[], + workflowNodesMap: WorkflowNodesMap, + getVarType?: GetVarType, + availableVariables?: NodeOutPutVar[], +): WorkflowVariableBlockNode { + return new WorkflowVariableBlockNode( + variables, + workflowNodesMap, + getVarType, + undefined, + availableVariables, + ) } export function $isWorkflowVariableBlockNode( diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx index 573c97f465..e3c947d786 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx @@ -15,19 +15,12 @@ import { WorkflowVariableBlockNode } from './index' import { $createWorkflowVariableBlockNode } from './node' const WorkflowVariableBlockReplacementBlock = ({ - workflowNodesMap, + workflowNodesMap = {}, getVarType, onInsert, variables, }: WorkflowVariableBlockType) => { const [editor] = useLexicalComposerContext() - const ragVariables = variables?.reduce((acc, curr) => { - if (curr.nodeId === 'rag') - acc.push(...curr.vars) - else - acc.push(...curr.vars.filter(v => v.isRagVariable)) - return acc - }, []) useEffect(() => { if (!editor.hasNodes([WorkflowVariableBlockNode])) @@ -39,8 +32,13 @@ const WorkflowVariableBlockReplacementBlock = ({ onInsert() const nodePathString = textNode.getTextContent().slice(3, -3) - return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap, getVarType, variables?.find(o => o.nodeId === 'env')?.vars || [], variables?.find(o => o.nodeId === 'conversation')?.vars || [], ragVariables)) - }, [onInsert, workflowNodesMap, getVarType, variables, ragVariables]) + return $applyNodeReplacement($createWorkflowVariableBlockNode( + nodePathString.split('.'), + workflowNodesMap, + getVarType, + variables || [], + )) + }, [onInsert, workflowNodesMap, getVarType, variables]) const getMatch = useCallback((text: string) => { const matchArr = REGEX.exec(text) diff --git a/web/app/components/plugins/__tests__/constants.spec.ts b/web/app/components/plugins/__tests__/constants.spec.ts new file mode 100644 index 0000000000..d3ec02c76c --- /dev/null +++ b/web/app/components/plugins/__tests__/constants.spec.ts @@ -0,0 +1,40 @@ +import { describe, expect, it } from 'vitest' +import { categoryKeys, tagKeys } from '../constants' +import { PluginCategoryEnum } from '../types' + +describe('plugin constants', () => { + it('exposes the expected plugin tag keys', () => { + expect(tagKeys).toEqual([ + 'agent', + 'rag', + 'search', + 'image', + 'videos', + 'weather', + 'finance', + 'design', + 'travel', + 'social', + 'news', + 'medical', + 'productivity', + 'education', + 'business', + 'entertainment', + 'utilities', + 'other', + ]) + }) + + it('exposes the expected category keys in display order', () => { + expect(categoryKeys).toEqual([ + PluginCategoryEnum.model, + PluginCategoryEnum.tool, + PluginCategoryEnum.datasource, + PluginCategoryEnum.agent, + PluginCategoryEnum.extension, + 'bundle', + PluginCategoryEnum.trigger, + ]) + }) +}) diff --git a/web/app/components/plugins/__tests__/provider-card.spec.tsx b/web/app/components/plugins/__tests__/provider-card.spec.tsx new file mode 100644 index 0000000000..71efd86bb0 --- /dev/null +++ b/web/app/components/plugins/__tests__/provider-card.spec.tsx @@ -0,0 +1,104 @@ +import type { Plugin } from '../types' +import { fireEvent, render, screen } from '@testing-library/react' +import { ThemeProvider } from 'next-themes' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import ProviderCard from '../provider-card' +import { PluginCategoryEnum } from '../types' + +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +vi.mock('@/hooks/use-i18n', () => ({ + useRenderI18nObject: () => (value: Record) => value['en-US'] || value.en_US, +})) + +vi.mock('@/app/components/plugins/install-plugin/install-from-marketplace', () => ({ + default: ({ onClose }: { onClose: () => void }) => ( +
+ +
+ ), +})) + +vi.mock('@/app/components/plugins/marketplace/utils', () => ({ + getPluginLinkInMarketplace: (plugin: Plugin, params: Record) => + `/marketplace/${plugin.org}/${plugin.name}?language=${params.language}&theme=${params.theme}`, +})) + +vi.mock('../card/base/card-icon', () => ({ + default: ({ src }: { src: string }) =>
{src}
, +})) + +vi.mock('../card/base/description', () => ({ + default: ({ text }: { text: string }) =>
{text}
, +})) + +vi.mock('../card/base/download-count', () => ({ + default: ({ downloadCount }: { downloadCount: number }) =>
{downloadCount}
, +})) + +vi.mock('../card/base/title', () => ({ + default: ({ title }: { title: string }) =>
{title}
, +})) + +const payload = { + type: 'plugin', + org: 'dify', + name: 'provider-one', + plugin_id: 'provider-one', + version: '1.0.0', + latest_version: '1.0.0', + latest_package_identifier: 'pkg-1', + icon: 'icon.png', + verified: true, + label: { 'en-US': 'Provider One' }, + brief: { 'en-US': 'Provider description' }, + description: { 'en-US': 'Full description' }, + introduction: 'Intro', + repository: 'https://github.com/dify/provider-one', + category: PluginCategoryEnum.tool, + install_count: 123, + endpoint: { settings: [] }, + tags: [{ name: 'search' }, { name: 'rag' }], + badges: [], + verification: { authorized_category: 'community' }, + from: 'marketplace', +} as Plugin + +describe('ProviderCard', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + const renderProviderCard = () => render( + + + , + ) + + it('renders provider information, tags, and detail link', () => { + renderProviderCard() + + expect(screen.getByTestId('title')).toHaveTextContent('Provider One') + expect(screen.getByText('dify')).toBeInTheDocument() + expect(screen.getByTestId('download-count')).toHaveTextContent('123') + expect(screen.getByTestId('description')).toHaveTextContent('Provider description') + expect(screen.getByText('search')).toBeInTheDocument() + expect(screen.getByText('rag')).toBeInTheDocument() + expect(screen.getByRole('link', { name: /plugin.detailPanel.operation.detail/i })).toHaveAttribute( + 'href', + '/marketplace/dify/provider-one?language=en-US&theme=system', + ) + }) + + it('opens and closes the install modal', () => { + renderProviderCard() + + fireEvent.click(screen.getByRole('button', { name: /plugin.detailPanel.operation.install/i })) + expect(screen.getByTestId('install-modal')).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('close-install-modal')) + expect(screen.queryByTestId('install-modal')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/install-plugin/base/__tests__/use-get-icon.spec.ts b/web/app/components/plugins/install-plugin/base/__tests__/use-get-icon.spec.ts new file mode 100644 index 0000000000..c5364ec47f --- /dev/null +++ b/web/app/components/plugins/install-plugin/base/__tests__/use-get-icon.spec.ts @@ -0,0 +1,22 @@ +import { renderHook } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import useGetIcon from '../use-get-icon' + +vi.mock('@/config', () => ({ + API_PREFIX: 'https://api.example.com', +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { currentWorkspace: { id: string } }) => string | { id: string }) => + selector({ currentWorkspace: { id: 'workspace-123' } }), +})) + +describe('useGetIcon', () => { + it('builds icon url with current workspace id', () => { + const { result } = renderHook(() => useGetIcon()) + + expect(result.current.getIconUrl('plugin-icon.png')).toBe( + 'https://api.example.com/workspaces/current/plugin/icon?tenant_id=workspace-123&filename=plugin-icon.png', + ) + }) +}) diff --git a/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/github-item.spec.tsx b/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/github-item.spec.tsx new file mode 100644 index 0000000000..12cd89765a --- /dev/null +++ b/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/github-item.spec.tsx @@ -0,0 +1,136 @@ +import type { GitHubItemAndMarketPlaceDependency, Plugin } from '../../../../types' +import type { VersionProps } from '@/app/components/plugins/types' +import { render, screen, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import GithubItem from '../github-item' + +const mockUseUploadGitHub = vi.fn() +const mockPluginManifestToCardPluginProps = vi.fn() +const mockLoadedItem = vi.fn() + +vi.mock('@/service/use-plugins', () => ({ + useUploadGitHub: (params: { repo: string, version: string, package: string }) => mockUseUploadGitHub(params), +})) + +vi.mock('../../../utils', () => ({ + pluginManifestToCardPluginProps: (manifest: unknown) => mockPluginManifestToCardPluginProps(manifest), +})) + +vi.mock('../../../base/loading', () => ({ + default: () =>
loading
, +})) + +vi.mock('../loaded-item', () => ({ + default: (props: Record) => { + mockLoadedItem(props) + return
loaded-item
+ }, +})) + +const dependency: GitHubItemAndMarketPlaceDependency = { + type: 'github', + value: { + repo: 'dify/plugin', + release: 'v1.0.0', + package: 'plugin.zip', + }, +} + +const versionInfo: VersionProps = { + hasInstalled: false, + installedVersion: '', + toInstallVersion: '1.0.0', +} + +describe('GithubItem', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders loading state before payload is ready', () => { + mockUseUploadGitHub.mockReturnValue({ data: null, error: null }) + + render( + , + ) + + expect(screen.getByTestId('loading')).toBeInTheDocument() + expect(mockUseUploadGitHub).toHaveBeenCalledWith({ + repo: 'dify/plugin', + version: 'v1.0.0', + package: 'plugin.zip', + }) + }) + + it('converts fetched manifest and renders LoadedItem', async () => { + const onFetchedPayload = vi.fn() + const payload = { + plugin_id: 'plugin-1', + name: 'Plugin One', + org: 'dify', + icon: 'icon.png', + version: '1.0.0', + } as Plugin + + mockUseUploadGitHub.mockReturnValue({ + data: { + manifest: { name: 'manifest' }, + unique_identifier: 'plugin-1', + }, + error: null, + }) + mockPluginManifestToCardPluginProps.mockReturnValue(payload) + + render( + , + ) + + await waitFor(() => { + expect(onFetchedPayload).toHaveBeenCalledWith(payload) + expect(screen.getByTestId('loaded-item')).toBeInTheDocument() + }) + + expect(mockLoadedItem).toHaveBeenCalledWith(expect.objectContaining({ + checked: true, + versionInfo, + payload: expect.objectContaining({ + ...payload, + from: 'github', + }), + })) + }) + + it('reports fetch error from upload hook', async () => { + const onFetchError = vi.fn() + mockUseUploadGitHub.mockReturnValue({ data: null, error: new Error('boom') }) + + render( + , + ) + + await waitFor(() => { + expect(onFetchError).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/loaded-item.spec.tsx b/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/loaded-item.spec.tsx new file mode 100644 index 0000000000..d19331a4e4 --- /dev/null +++ b/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/loaded-item.spec.tsx @@ -0,0 +1,160 @@ +import type { Plugin } from '../../../../types' +import type { VersionProps } from '@/app/components/plugins/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import LoadedItem from '../loaded-item' + +const mockCheckbox = vi.fn() +const mockCard = vi.fn() +const mockVersion = vi.fn() +const mockUsePluginInstallLimit = vi.fn() + +vi.mock('@/config', () => ({ + API_PREFIX: 'https://api.example.com', + MARKETPLACE_API_PREFIX: 'https://marketplace.example.com', +})) + +vi.mock('@/app/components/base/checkbox', () => ({ + default: (props: { checked: boolean, disabled: boolean, onCheck: () => void }) => { + mockCheckbox(props) + return ( + + ) + }, +})) + +vi.mock('../../../../card', () => ({ + default: (props: { titleLeft?: React.ReactNode }) => { + mockCard(props) + return ( +
+ {props.titleLeft} +
+ ) + }, +})) + +vi.mock('../../../base/use-get-icon', () => ({ + default: () => ({ + getIconUrl: (icon: string) => `https://api.example.com/${icon}`, + }), +})) + +vi.mock('../../../base/version', () => ({ + default: (props: Record) => { + mockVersion(props) + return
version
+ }, +})) + +vi.mock('../../../hooks/use-install-plugin-limit', () => ({ + default: (payload: Plugin) => mockUsePluginInstallLimit(payload), +})) + +const payload = { + plugin_id: 'plugin-1', + org: 'dify', + name: 'Loaded Plugin', + icon: 'icon.png', + version: '1.0.0', +} as Plugin + +const versionInfo: VersionProps = { + hasInstalled: false, + installedVersion: '', + toInstallVersion: '0.9.0', +} + +describe('LoadedItem', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUsePluginInstallLimit.mockReturnValue({ canInstall: true }) + }) + + it('uses local icon url and forwards version title for non-marketplace plugins', () => { + render( + , + ) + + expect(screen.getByTestId('card')).toBeInTheDocument() + expect(mockUsePluginInstallLimit).toHaveBeenCalledWith(payload) + expect(mockCard).toHaveBeenCalledWith(expect.objectContaining({ + limitedInstall: false, + payload: expect.objectContaining({ + ...payload, + icon: 'https://api.example.com/icon.png', + }), + titleLeft: expect.anything(), + })) + expect(mockVersion).toHaveBeenCalledWith(expect.objectContaining({ + hasInstalled: false, + installedVersion: '', + toInstallVersion: '1.0.0', + })) + }) + + it('uses marketplace icon url and disables checkbox when install limit is reached', () => { + mockUsePluginInstallLimit.mockReturnValue({ canInstall: false }) + + render( + , + ) + + expect(screen.getByTestId('checkbox')).toBeDisabled() + expect(mockCard).toHaveBeenCalledWith(expect.objectContaining({ + limitedInstall: true, + payload: expect.objectContaining({ + icon: 'https://marketplace.example.com/plugins/dify/Loaded Plugin/icon', + }), + })) + }) + + it('calls onCheckedChange with payload when checkbox is toggled', () => { + const onCheckedChange = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByTestId('checkbox')) + + expect(onCheckedChange).toHaveBeenCalledWith(payload) + }) + + it('omits version badge when payload has no version', () => { + render( + , + ) + + expect(mockCard).toHaveBeenCalledWith(expect.objectContaining({ + titleLeft: null, + })) + }) +}) diff --git a/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/marketplace-item.spec.tsx b/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/marketplace-item.spec.tsx new file mode 100644 index 0000000000..b6c1763ac5 --- /dev/null +++ b/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/marketplace-item.spec.tsx @@ -0,0 +1,69 @@ +import type { Plugin } from '../../../../types' +import type { VersionProps } from '@/app/components/plugins/types' +import { render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import MarketPlaceItem from '../marketplace-item' + +const mockLoadedItem = vi.fn() + +vi.mock('../../../base/loading', () => ({ + default: () =>
loading
, +})) + +vi.mock('../loaded-item', () => ({ + default: (props: Record) => { + mockLoadedItem(props) + return
loaded-item
+ }, +})) + +const payload = { + plugin_id: 'plugin-1', + org: 'dify', + name: 'Marketplace Plugin', + icon: 'icon.png', +} as Plugin + +const versionInfo: VersionProps = { + hasInstalled: false, + installedVersion: '', + toInstallVersion: '1.0.0', +} + +describe('MarketPlaceItem', () => { + it('renders loading when payload is absent', () => { + render( + , + ) + + expect(screen.getByTestId('loading')).toBeInTheDocument() + }) + + it('renders LoadedItem with marketplace payload and version', () => { + render( + , + ) + + expect(screen.getByTestId('loaded-item')).toBeInTheDocument() + expect(mockLoadedItem).toHaveBeenCalledWith(expect.objectContaining({ + checked: true, + isFromMarketPlace: true, + versionInfo, + payload: expect.objectContaining({ + ...payload, + version: '2.0.0', + }), + })) + }) +}) diff --git a/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/package-item.spec.tsx b/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/package-item.spec.tsx new file mode 100644 index 0000000000..e92faeb77f --- /dev/null +++ b/web/app/components/plugins/install-plugin/install-bundle/item/__tests__/package-item.spec.tsx @@ -0,0 +1,124 @@ +import type { PackageDependency } from '../../../../types' +import type { VersionProps } from '@/app/components/plugins/types' +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PluginCategoryEnum } from '../../../../types' +import PackageItem from '../package-item' + +const mockPluginManifestToCardPluginProps = vi.fn() +const mockLoadedItem = vi.fn() + +vi.mock('../../../utils', () => ({ + pluginManifestToCardPluginProps: (manifest: unknown) => mockPluginManifestToCardPluginProps(manifest), +})) + +vi.mock('../../../base/loading-error', () => ({ + default: () =>
loading-error
, +})) + +vi.mock('../loaded-item', () => ({ + default: (props: Record) => { + mockLoadedItem(props) + return
loaded-item
+ }, +})) + +const versionInfo: VersionProps = { + hasInstalled: false, + installedVersion: '', + toInstallVersion: '1.0.0', +} + +const payload = { + type: 'package', + value: { + manifest: { + plugin_unique_identifier: 'plugin-1', + version: '1.0.0', + author: 'dify', + icon: 'icon.png', + name: 'Package Plugin', + category: PluginCategoryEnum.tool, + label: { en_US: 'Package Plugin', zh_Hans: 'Package Plugin' }, + description: { en_US: 'Description', zh_Hans: 'Description' }, + created_at: '2024-01-01', + resource: {}, + plugins: [], + verified: true, + endpoint: { settings: [], endpoints: [] }, + model: null, + tags: [], + agent_strategy: null, + meta: { version: '1.0.0' }, + trigger: { + events: [], + identity: { + author: 'dify', + name: 'trigger', + description: { en_US: 'Trigger', zh_Hans: 'Trigger' }, + icon: 'icon.png', + label: { en_US: 'Trigger', zh_Hans: 'Trigger' }, + tags: [], + }, + subscription_constructor: { + credentials_schema: [], + oauth_schema: { + client_schema: [], + credentials_schema: [], + }, + parameters: [], + }, + subscription_schema: [], + }, + }, + }, +} as unknown as PackageDependency + +describe('PackageItem', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders loading error when manifest is missing', () => { + render( + , + ) + + expect(screen.getByTestId('loading-error')).toBeInTheDocument() + }) + + it('renders LoadedItem with converted plugin payload', () => { + mockPluginManifestToCardPluginProps.mockReturnValue({ + plugin_id: 'plugin-1', + name: 'Package Plugin', + org: 'dify', + icon: 'icon.png', + }) + + render( + , + ) + + expect(screen.getByTestId('loaded-item')).toBeInTheDocument() + expect(mockLoadedItem).toHaveBeenCalledWith(expect.objectContaining({ + checked: true, + isFromMarketPlace: true, + versionInfo, + payload: expect.objectContaining({ + plugin_id: 'plugin-1', + from: 'package', + }), + })) + }) +}) diff --git a/web/app/components/plugins/install-plugin/install-bundle/steps/__tests__/installed.spec.tsx b/web/app/components/plugins/install-plugin/install-bundle/steps/__tests__/installed.spec.tsx new file mode 100644 index 0000000000..9ae67b7d16 --- /dev/null +++ b/web/app/components/plugins/install-plugin/install-bundle/steps/__tests__/installed.spec.tsx @@ -0,0 +1,114 @@ +import type { InstallStatus, Plugin } from '../../../../types' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import Installed from '../installed' + +const mockCard = vi.fn() + +vi.mock('@/config', () => ({ + API_PREFIX: 'https://api.example.com', + MARKETPLACE_API_PREFIX: 'https://marketplace.example.com', +})) + +vi.mock('@/app/components/plugins/card', () => ({ + default: (props: { titleLeft?: React.ReactNode }) => { + mockCard(props) + return ( +
+ {props.titleLeft} +
+ ) + }, +})) + +vi.mock('../../../base/use-get-icon', () => ({ + default: () => ({ + getIconUrl: (icon: string) => `https://api.example.com/${icon}`, + }), +})) + +const plugins = [ + { + plugin_id: 'plugin-1', + org: 'dify', + name: 'Plugin One', + icon: 'icon-1.png', + version: '1.0.0', + }, + { + plugin_id: 'plugin-2', + org: 'dify', + name: 'Plugin Two', + icon: 'icon-2.png', + version: '2.0.0', + }, +] as Plugin[] + +const installStatus: InstallStatus[] = [ + { success: true, isFromMarketPlace: true }, + { success: false, isFromMarketPlace: false }, +] + +describe('Installed', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders plugin cards with install status and marketplace icon handling', () => { + render( + , + ) + + expect(screen.getAllByTestId('card')).toHaveLength(2) + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() + expect(screen.getByText('1.0.0')).toBeInTheDocument() + expect(screen.getByText('2.0.0')).toBeInTheDocument() + expect(mockCard).toHaveBeenNthCalledWith(1, expect.objectContaining({ + installed: true, + installFailed: false, + payload: expect.objectContaining({ + icon: 'https://marketplace.example.com/plugins/dify/Plugin One/icon', + }), + })) + expect(mockCard).toHaveBeenNthCalledWith(2, expect.objectContaining({ + installed: false, + installFailed: true, + payload: expect.objectContaining({ + icon: 'https://api.example.com/icon-2.png', + }), + })) + }) + + it('calls onCancel when close button is clicked', () => { + const onCancel = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.close' })) + + expect(onCancel).toHaveBeenCalledTimes(1) + }) + + it('hides action button when isHideButton is true', () => { + render( + , + ) + + expect(screen.queryByRole('button', { name: 'common.operation.close' })).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/marketplace/__tests__/constants.spec.ts b/web/app/components/plugins/marketplace/__tests__/constants.spec.ts new file mode 100644 index 0000000000..cb3c822993 --- /dev/null +++ b/web/app/components/plugins/marketplace/__tests__/constants.spec.ts @@ -0,0 +1,37 @@ +import { describe, expect, it } from 'vitest' +import { PluginCategoryEnum } from '../../types' +import { + DEFAULT_SORT, + PLUGIN_CATEGORY_WITH_COLLECTIONS, + PLUGIN_TYPE_SEARCH_MAP, + SCROLL_BOTTOM_THRESHOLD, +} from '../constants' + +describe('marketplace constants', () => { + it('defines the expected default sort', () => { + expect(DEFAULT_SORT).toEqual({ + sortBy: 'install_count', + sortOrder: 'DESC', + }) + }) + + it('defines the expected plugin search type map', () => { + expect(PLUGIN_TYPE_SEARCH_MAP).toEqual({ + all: 'all', + model: PluginCategoryEnum.model, + tool: PluginCategoryEnum.tool, + agent: PluginCategoryEnum.agent, + extension: PluginCategoryEnum.extension, + datasource: PluginCategoryEnum.datasource, + trigger: PluginCategoryEnum.trigger, + bundle: 'bundle', + }) + expect(SCROLL_BOTTOM_THRESHOLD).toBe(100) + }) + + it('tracks only collection-backed categories', () => { + expect(PLUGIN_CATEGORY_WITH_COLLECTIONS.has(PLUGIN_TYPE_SEARCH_MAP.all)).toBe(true) + expect(PLUGIN_CATEGORY_WITH_COLLECTIONS.has(PLUGIN_TYPE_SEARCH_MAP.tool)).toBe(true) + expect(PLUGIN_CATEGORY_WITH_COLLECTIONS.has(PLUGIN_TYPE_SEARCH_MAP.model)).toBe(false) + }) +}) diff --git a/web/app/components/plugins/marketplace/__tests__/search-params.spec.ts b/web/app/components/plugins/marketplace/__tests__/search-params.spec.ts new file mode 100644 index 0000000000..c13a4528fb --- /dev/null +++ b/web/app/components/plugins/marketplace/__tests__/search-params.spec.ts @@ -0,0 +1,18 @@ +import { describe, expect, it } from 'vitest' +import { PLUGIN_TYPE_SEARCH_MAP } from '../constants' +import { marketplaceSearchParamsParsers } from '../search-params' + +describe('marketplace search params', () => { + it('applies the expected default values', () => { + expect(marketplaceSearchParamsParsers.category.parseServerSide(undefined)).toBe(PLUGIN_TYPE_SEARCH_MAP.all) + expect(marketplaceSearchParamsParsers.q.parseServerSide(undefined)).toBe('') + expect(marketplaceSearchParamsParsers.tags.parseServerSide(undefined)).toEqual([]) + }) + + it('parses supported query values with the configured parsers', () => { + expect(marketplaceSearchParamsParsers.category.parseServerSide(PLUGIN_TYPE_SEARCH_MAP.tool)).toBe(PLUGIN_TYPE_SEARCH_MAP.tool) + expect(marketplaceSearchParamsParsers.category.parseServerSide('unsupported')).toBe(PLUGIN_TYPE_SEARCH_MAP.all) + expect(marketplaceSearchParamsParsers.q.parseServerSide('keyword')).toBe('keyword') + expect(marketplaceSearchParamsParsers.tags.parseServerSide('rag,search')).toEqual(['rag', 'search']) + }) +}) diff --git a/web/app/components/plugins/marketplace/empty/__tests__/line.spec.tsx b/web/app/components/plugins/marketplace/empty/__tests__/line.spec.tsx new file mode 100644 index 0000000000..56e7046dae --- /dev/null +++ b/web/app/components/plugins/marketplace/empty/__tests__/line.spec.tsx @@ -0,0 +1,30 @@ +import { render } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import Line from '../line' + +const mockUseTheme = vi.fn() + +vi.mock('@/hooks/use-theme', () => ({ + default: () => mockUseTheme(), +})) + +describe('Line', () => { + it('renders dark mode svg variant', () => { + mockUseTheme.mockReturnValue({ theme: 'dark' }) + const { container } = render() + const svg = container.querySelector('svg') + + expect(svg).toHaveAttribute('height', '240') + expect(svg).toHaveAttribute('viewBox', '0 0 2 240') + expect(svg).toHaveClass('divider') + }) + + it('renders light mode svg variant', () => { + mockUseTheme.mockReturnValue({ theme: 'light' }) + const { container } = render() + const svg = container.querySelector('svg') + + expect(svg).toHaveAttribute('height', '241') + expect(svg).toHaveAttribute('viewBox', '0 0 2 241') + }) +}) diff --git a/web/app/components/plugins/marketplace/list/__tests__/card-wrapper.spec.tsx b/web/app/components/plugins/marketplace/list/__tests__/card-wrapper.spec.tsx new file mode 100644 index 0000000000..f1e263b6f6 --- /dev/null +++ b/web/app/components/plugins/marketplace/list/__tests__/card-wrapper.spec.tsx @@ -0,0 +1,115 @@ +import type { ComponentProps } from 'react' +import type { Plugin } from '@/app/components/plugins/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { ThemeProvider } from 'next-themes' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PluginCategoryEnum } from '@/app/components/plugins/types' +import CardWrapper from '../card-wrapper' + +vi.mock('#i18n', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key, + }), + useLocale: () => 'en-US', +})) + +vi.mock('@/app/components/plugins/hooks', () => ({ + useTags: () => ({ + getTagLabel: (name: string) => `tag:${name}`, + }), +})) + +vi.mock('@/app/components/plugins/card', () => ({ + default: ({ payload, footer }: { payload: Plugin, footer?: React.ReactNode }) => ( +
+ {payload.name} + {footer} +
+ ), +})) + +vi.mock('@/app/components/plugins/card/card-more-info', () => ({ + default: ({ downloadCount, tags }: { downloadCount: number, tags: string[] }) => ( +
+ {downloadCount} + : + {tags.join('|')} +
+ ), +})) + +vi.mock('@/app/components/plugins/install-plugin/install-from-marketplace', () => ({ + default: ({ onClose }: { onClose: () => void }) => ( +
+ +
+ ), +})) + +vi.mock('../../utils', () => ({ + getPluginDetailLinkInMarketplace: (plugin: Plugin) => `/detail/${plugin.org}/${plugin.name}`, + getPluginLinkInMarketplace: (plugin: Plugin, params: Record) => `/marketplace/${plugin.org}/${plugin.name}?language=${params.language}&theme=${params.theme}`, +})) + +const plugin = { + type: 'plugin', + org: 'dify', + name: 'plugin-a', + plugin_id: 'plugin-a', + version: '1.0.0', + latest_version: '1.0.0', + latest_package_identifier: 'pkg', + icon: 'icon.png', + verified: true, + label: { 'en-US': 'Plugin A' }, + brief: { 'en-US': 'Brief' }, + description: { 'en-US': 'Description' }, + introduction: 'Intro', + repository: 'https://github.com/dify/plugin-a', + category: PluginCategoryEnum.tool, + install_count: 42, + endpoint: { settings: [] }, + tags: [{ name: 'search' }, { name: 'agent' }], + badges: [], + verification: { authorized_category: 'community' }, + from: 'marketplace', +} as Plugin + +describe('CardWrapper', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + const renderCardWrapper = (props: Partial> = {}) => render( + + + , + ) + + it('renders plugin detail link when install button is hidden', () => { + renderCardWrapper() + + expect(screen.getByRole('link')).toHaveAttribute('href', '/detail/dify/plugin-a') + expect(screen.getByTestId('card-more-info')).toHaveTextContent('42:tag:search|tag:agent') + }) + + it('renders install and marketplace detail actions when install button is shown', () => { + renderCardWrapper({ showInstallButton: true }) + + expect(screen.getByRole('button', { name: 'plugin.detailPanel.operation.install' })).toBeInTheDocument() + expect(screen.getByRole('link', { name: 'plugin.detailPanel.operation.detail' })).toHaveAttribute( + 'href', + '/marketplace/dify/plugin-a?language=en-US&theme=system', + ) + }) + + it('opens and closes install modal from install action', () => { + renderCardWrapper({ showInstallButton: true }) + + fireEvent.click(screen.getByRole('button', { name: 'plugin.detailPanel.operation.install' })) + expect(screen.getByTestId('install-modal')).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('close-install-modal')) + expect(screen.queryByTestId('install-modal')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/marketplace/list/__tests__/list-with-collection.spec.tsx b/web/app/components/plugins/marketplace/list/__tests__/list-with-collection.spec.tsx new file mode 100644 index 0000000000..cbaf7868a0 --- /dev/null +++ b/web/app/components/plugins/marketplace/list/__tests__/list-with-collection.spec.tsx @@ -0,0 +1,102 @@ +import type { MarketplaceCollection } from '../../types' +import type { Plugin } from '@/app/components/plugins/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import ListWithCollection from '../list-with-collection' + +const mockMoreClick = vi.fn() + +vi.mock('#i18n', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key, + }), + useLocale: () => 'en-US', +})) + +vi.mock('../../atoms', () => ({ + useMarketplaceMoreClick: () => mockMoreClick, +})) + +vi.mock('@/i18n-config/language', () => ({ + getLanguage: (locale: string) => locale, +})) + +vi.mock('../card-wrapper', () => ({ + default: ({ plugin }: { plugin: Plugin }) =>
{plugin.name}
, +})) + +const collections: MarketplaceCollection[] = [ + { + name: 'featured', + label: { 'en-US': 'Featured' }, + description: { 'en-US': 'Featured plugins' }, + rule: 'featured', + created_at: '', + updated_at: '', + searchable: true, + search_params: { query: 'featured' }, + }, + { + name: 'empty', + label: { 'en-US': 'Empty' }, + description: { 'en-US': 'No plugins' }, + rule: 'empty', + created_at: '', + updated_at: '', + searchable: false, + search_params: {}, + }, +] + +const pluginsMap: Record = { + featured: [ + { plugin_id: 'p1', name: 'Plugin One' }, + { plugin_id: 'p2', name: 'Plugin Two' }, + ] as Plugin[], + empty: [], +} + +describe('ListWithCollection', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders only collections that contain plugins', () => { + render( + , + ) + + expect(screen.getByText('Featured')).toBeInTheDocument() + expect(screen.queryByText('Empty')).not.toBeInTheDocument() + expect(screen.getAllByTestId('card-wrapper')).toHaveLength(2) + }) + + it('calls more handler for searchable collection', () => { + render( + , + ) + + fireEvent.click(screen.getByText('plugin.marketplace.viewMore')) + + expect(mockMoreClick).toHaveBeenCalledWith({ query: 'featured' }) + }) + + it('uses custom card renderer when provided', () => { + render( +
{plugin.name}
} + />, + ) + + expect(screen.getAllByTestId('custom-card')).toHaveLength(2) + expect(screen.queryByTestId('card-wrapper')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/marketplace/list/__tests__/list-wrapper.spec.tsx b/web/app/components/plugins/marketplace/list/__tests__/list-wrapper.spec.tsx new file mode 100644 index 0000000000..fecfea3007 --- /dev/null +++ b/web/app/components/plugins/marketplace/list/__tests__/list-wrapper.spec.tsx @@ -0,0 +1,92 @@ +import type { MarketplaceCollection } from '../../types' +import type { Plugin } from '@/app/components/plugins/types' +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import ListWrapper from '../list-wrapper' + +const mockMarketplaceData = vi.hoisted(() => ({ + plugins: undefined as Plugin[] | undefined, + pluginsTotal: 0, + marketplaceCollections: [] as MarketplaceCollection[], + marketplaceCollectionPluginsMap: {} as Record, + isLoading: false, + isFetchingNextPage: false, + page: 1, +})) + +vi.mock('#i18n', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string, num?: number }) => + key === 'marketplace.pluginsResult' && options?.ns === 'plugin' + ? `${options.num} plugins found` + : options?.ns ? `${options.ns}.${key}` : key, + }), +})) + +vi.mock('../../state', () => ({ + useMarketplaceData: () => mockMarketplaceData, +})) + +vi.mock('@/app/components/base/loading', () => ({ + default: ({ className }: { className?: string }) =>
loading
, +})) + +vi.mock('../../sort-dropdown', () => ({ + default: () =>
sort
, +})) + +vi.mock('../index', () => ({ + default: ({ plugins }: { plugins?: Plugin[] }) =>
{plugins?.length ?? 'collections'}
, +})) + +describe('ListWrapper', () => { + beforeEach(() => { + vi.clearAllMocks() + mockMarketplaceData.plugins = undefined + mockMarketplaceData.pluginsTotal = 0 + mockMarketplaceData.marketplaceCollections = [] + mockMarketplaceData.marketplaceCollectionPluginsMap = {} + mockMarketplaceData.isLoading = false + mockMarketplaceData.isFetchingNextPage = false + mockMarketplaceData.page = 1 + }) + + it('shows result header and sort dropdown when plugins are loaded', () => { + mockMarketplaceData.plugins = [{ plugin_id: 'p1', name: 'Plugin One' } as Plugin] + mockMarketplaceData.pluginsTotal = 1 + + render() + + expect(screen.getByText('1 plugins found')).toBeInTheDocument() + expect(screen.getByTestId('sort-dropdown')).toBeInTheDocument() + }) + + it('shows centered loading only on initial loading page', () => { + mockMarketplaceData.isLoading = true + mockMarketplaceData.page = 1 + + render() + + expect(screen.getByTestId('loading')).toBeInTheDocument() + expect(screen.queryByTestId('list')).not.toBeInTheDocument() + }) + + it('renders list when loading additional pages', () => { + mockMarketplaceData.isLoading = true + mockMarketplaceData.page = 2 + mockMarketplaceData.plugins = [{ plugin_id: 'p1', name: 'Plugin One' } as Plugin] + + render() + + expect(screen.getByTestId('list')).toBeInTheDocument() + }) + + it('shows bottom loading indicator while fetching next page', () => { + mockMarketplaceData.plugins = [{ plugin_id: 'p1', name: 'Plugin One' } as Plugin] + mockMarketplaceData.isFetchingNextPage = true + + render() + + expect(screen.getAllByTestId('loading')).toHaveLength(1) + }) +}) diff --git a/web/app/components/plugins/marketplace/search-box/__tests__/search-box-wrapper.spec.tsx b/web/app/components/plugins/marketplace/search-box/__tests__/search-box-wrapper.spec.tsx new file mode 100644 index 0000000000..4a3b880c27 --- /dev/null +++ b/web/app/components/plugins/marketplace/search-box/__tests__/search-box-wrapper.spec.tsx @@ -0,0 +1,43 @@ +import { render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import SearchBoxWrapper from '../search-box-wrapper' + +const mockHandleSearchPluginTextChange = vi.fn() +const mockHandleFilterPluginTagsChange = vi.fn() +const mockSearchBox = vi.fn() + +vi.mock('#i18n', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key, + }), +})) + +vi.mock('../../atoms', () => ({ + useSearchPluginText: () => ['plugin search', mockHandleSearchPluginTextChange], + useFilterPluginTags: () => [['agent', 'rag'], mockHandleFilterPluginTagsChange], +})) + +vi.mock('../index', () => ({ + default: (props: Record) => { + mockSearchBox(props) + return
search-box
+ }, +})) + +describe('SearchBoxWrapper', () => { + it('passes marketplace search state into SearchBox', () => { + render() + + expect(screen.getByTestId('search-box')).toBeInTheDocument() + expect(mockSearchBox).toHaveBeenCalledWith(expect.objectContaining({ + wrapperClassName: 'z-11 mx-auto w-[640px] shrink-0', + inputClassName: 'w-full', + search: 'plugin search', + onSearchChange: mockHandleSearchPluginTextChange, + tags: ['agent', 'rag'], + onTagsChange: mockHandleFilterPluginTagsChange, + placeholder: 'plugin.searchPlugins', + usedInMarketplace: true, + })) + }) +}) diff --git a/web/app/components/plugins/marketplace/search-box/__tests__/tags-filter.spec.tsx b/web/app/components/plugins/marketplace/search-box/__tests__/tags-filter.spec.tsx new file mode 100644 index 0000000000..bb5d8e734c --- /dev/null +++ b/web/app/components/plugins/marketplace/search-box/__tests__/tags-filter.spec.tsx @@ -0,0 +1,126 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import TagsFilter from '../tags-filter' + +vi.mock('#i18n', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key, + }), +})) + +vi.mock('@/app/components/plugins/hooks', () => ({ + useTags: () => ({ + tags: [ + { name: 'agent', label: 'Agent' }, + { name: 'rag', label: 'RAG' }, + { name: 'search', label: 'Search' }, + ], + tagsMap: { + agent: { name: 'agent', label: 'Agent' }, + rag: { name: 'rag', label: 'RAG' }, + search: { name: 'search', label: 'Search' }, + }, + }), +})) + +vi.mock('@/app/components/base/checkbox', () => ({ + default: ({ checked }: { checked: boolean }) => {String(checked)}, +})) + +vi.mock('@/app/components/base/input', () => ({ + default: ({ + value, + onChange, + placeholder, + }: { + value: string + onChange: (event: { target: { value: string } }) => void + placeholder: string + }) => ( + onChange({ target: { value: event.target.value } })} + /> + ), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', async () => { + const React = await import('react') + return { + PortalToFollowElem: ({ children }: { children: React.ReactNode }) =>
{children}
, + PortalToFollowElemTrigger: ({ + children, + onClick, + }: { + children: React.ReactNode + onClick: () => void + }) => , + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) =>
{children}
, + } +}) + +vi.mock('../trigger/marketplace', () => ({ + default: ({ selectedTagsLength }: { selectedTagsLength: number }) => ( +
+ marketplace: + {selectedTagsLength} +
+ ), +})) + +vi.mock('../trigger/tool-selector', () => ({ + default: ({ selectedTagsLength }: { selectedTagsLength: number }) => ( +
+ tool: + {selectedTagsLength} +
+ ), +})) + +describe('TagsFilter', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders marketplace trigger when used in marketplace', () => { + render() + + expect(screen.getByTestId('marketplace-trigger')).toHaveTextContent('marketplace:1') + expect(screen.queryByTestId('tool-trigger')).not.toBeInTheDocument() + }) + + it('renders tool selector trigger when used outside marketplace', () => { + render() + + expect(screen.getByTestId('tool-trigger')).toHaveTextContent('tool:1') + expect(screen.queryByTestId('marketplace-trigger')).not.toBeInTheDocument() + }) + + it('filters tag options by search text', () => { + render() + + expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('RAG')).toBeInTheDocument() + expect(screen.getByText('Search')).toBeInTheDocument() + + fireEvent.change(screen.getByLabelText('tags-search'), { target: { value: 'ra' } }) + + expect(screen.queryByText('Agent')).not.toBeInTheDocument() + expect(screen.getByText('RAG')).toBeInTheDocument() + expect(screen.queryByText('Search')).not.toBeInTheDocument() + }) + + it('adds and removes selected tags when options are clicked', () => { + const onTagsChange = vi.fn() + const { rerender } = render() + + fireEvent.click(screen.getByText('Agent')) + expect(onTagsChange).toHaveBeenCalledWith([]) + + rerender() + fireEvent.click(screen.getByText('RAG')) + expect(onTagsChange).toHaveBeenCalledWith(['agent', 'rag']) + }) +}) diff --git a/web/app/components/plugins/marketplace/search-box/trigger/__tests__/marketplace.spec.tsx b/web/app/components/plugins/marketplace/search-box/trigger/__tests__/marketplace.spec.tsx new file mode 100644 index 0000000000..4d1a11ac00 --- /dev/null +++ b/web/app/components/plugins/marketplace/search-box/trigger/__tests__/marketplace.spec.tsx @@ -0,0 +1,67 @@ +import type { Tag } from '../../../../hooks' +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import MarketplaceTrigger from '../marketplace' + +vi.mock('#i18n', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key, + }), +})) + +const tagsMap: Record = { + agent: { name: 'agent', label: 'Agent' }, + rag: { name: 'rag', label: 'RAG' }, + search: { name: 'search', label: 'Search' }, +} + +describe('MarketplaceTrigger', () => { + it('shows all-tags text when no tags are selected', () => { + const { container } = render( + , + ) + + expect(screen.getByText('pluginTags.allTags')).toBeInTheDocument() + expect(container.querySelectorAll('svg').length).toBeGreaterThan(0) + expect(container.querySelectorAll('svg').length).toBe(2) + }) + + it('shows selected tag labels and overflow count', () => { + render( + , + ) + + expect(screen.getByText('Agent,RAG')).toBeInTheDocument() + expect(screen.getByText('+1')).toBeInTheDocument() + }) + + it('clears selected tags when clear icon is clicked', () => { + const onTagsChange = vi.fn() + + const { container } = render( + , + ) + + fireEvent.click(container.querySelectorAll('svg')[1]!) + + expect(onTagsChange).toHaveBeenCalledWith([]) + }) +}) diff --git a/web/app/components/plugins/marketplace/search-box/trigger/__tests__/tool-selector.spec.tsx b/web/app/components/plugins/marketplace/search-box/trigger/__tests__/tool-selector.spec.tsx new file mode 100644 index 0000000000..7e9069d61f --- /dev/null +++ b/web/app/components/plugins/marketplace/search-box/trigger/__tests__/tool-selector.spec.tsx @@ -0,0 +1,61 @@ +import type { Tag } from '../../../../hooks' +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import ToolSelectorTrigger from '../tool-selector' + +const tagsMap: Record = { + agent: { name: 'agent', label: 'Agent' }, + rag: { name: 'rag', label: 'RAG' }, + search: { name: 'search', label: 'Search' }, +} + +describe('ToolSelectorTrigger', () => { + it('renders only icon when no tags are selected', () => { + const { container } = render( + , + ) + + expect(container.querySelectorAll('svg')).toHaveLength(1) + expect(screen.queryByText('Agent')).not.toBeInTheDocument() + }) + + it('renders selected tag labels and overflow count', () => { + const { container } = render( + , + ) + + expect(screen.getByText('Agent,RAG')).toBeInTheDocument() + expect(screen.getByText('+1')).toBeInTheDocument() + expect(container.querySelectorAll('svg')).toHaveLength(2) + }) + + it('clears selected tags when clear icon is clicked', () => { + const onTagsChange = vi.fn() + + const { container } = render( + , + ) + + fireEvent.click(container.querySelectorAll('svg')[1]!) + + expect(onTagsChange).toHaveBeenCalledWith([]) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-inputs-form.spec.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-inputs-form.spec.tsx new file mode 100644 index 0000000000..f3dcfeab5d --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-inputs-form.spec.tsx @@ -0,0 +1,106 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { InputVarType } from '@/app/components/workflow/types' +import AppInputsForm from '../app-inputs-form' + +vi.mock('@/app/components/base/file-uploader', () => ({ + FileUploaderInAttachmentWrapper: ({ + onChange, + }: { + onChange: (files: Array>) => void + }) => ( + + ), +})) + +vi.mock('@/app/components/base/select', () => ({ + PortalSelect: ({ + items, + onSelect, + }: { + items: Array<{ value: string, name: string }> + onSelect: (item: { value: string }) => void + }) => ( +
+ {items.map(item => ( + + ))} +
+ ), +})) + +describe('AppInputsForm', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should update text input values', () => { + const onFormChange = vi.fn() + const inputsRef = { current: { question: '' } } + + render( + , + ) + + fireEvent.change(screen.getByPlaceholderText('Question'), { + target: { value: 'hello' }, + }) + + expect(onFormChange).toHaveBeenCalledWith({ question: 'hello' }) + }) + + it('should update select values', () => { + const onFormChange = vi.fn() + const inputsRef = { current: { tone: '' } } + + render( + , + ) + + fireEvent.click(screen.getByTestId('select-formal')) + + expect(onFormChange).toHaveBeenCalledWith({ tone: 'formal' }) + }) + + it('should update uploaded single file values', () => { + const onFormChange = vi.fn() + const inputsRef = { current: { attachment: null } } + + render( + , + ) + + fireEvent.click(screen.getByTestId('file-uploader')) + + expect(onFormChange).toHaveBeenCalledWith({ + attachment: { id: 'file-1', name: 'demo.png' }, + }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-inputs-panel.spec.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-inputs-panel.spec.tsx new file mode 100644 index 0000000000..3e1c2a5a2a --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-inputs-panel.spec.tsx @@ -0,0 +1,87 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import AppInputsPanel from '../app-inputs-panel' + +let mockHookResult = { + inputFormSchema: [] as Array>, + isLoading: false, +} + +vi.mock('@/app/components/base/loading', () => ({ + default: () =>
Loading
, +})) + +vi.mock('@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-form', () => ({ + default: ({ + onFormChange, + }: { + onFormChange: (value: Record) => void + }) => ( + + ), +})) + +vi.mock('@/app/components/plugins/plugin-detail-panel/app-selector/hooks/use-app-inputs-form-schema', () => ({ + useAppInputsFormSchema: () => mockHookResult, +})) + +describe('AppInputsPanel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockHookResult = { + inputFormSchema: [], + isLoading: false, + } + }) + + it('should render a loading state', () => { + mockHookResult = { + inputFormSchema: [], + isLoading: true, + } + + render( + , + ) + + expect(screen.getByTestId('loading')).toBeInTheDocument() + }) + + it('should render an empty state when no inputs are available', () => { + render( + , + ) + + expect(screen.getByText('app.appSelector.noParams')).toBeInTheDocument() + }) + + it('should render the inputs form and propagate changes', () => { + const onFormChange = vi.fn() + mockHookResult = { + inputFormSchema: [{ variable: 'topic' }], + isLoading: false, + } + + render( + , + ) + + fireEvent.click(screen.getByTestId('app-inputs-form')) + + expect(onFormChange).toHaveBeenCalledWith({ topic: 'updated' }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-picker.spec.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-picker.spec.tsx new file mode 100644 index 0000000000..a319d2f8c4 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/app-picker.spec.tsx @@ -0,0 +1,179 @@ +import type { ReactNode } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' +import { AppModeEnum } from '@/types/app' +import AppPicker from '../app-picker' + +class MockIntersectionObserver { + observe = vi.fn() + disconnect = vi.fn() + unobserve = vi.fn() +} + +class MockMutationObserver { + observe = vi.fn() + disconnect = vi.fn() + takeRecords = vi.fn().mockReturnValue([]) +} + +beforeAll(() => { + vi.stubGlobal('IntersectionObserver', MockIntersectionObserver) + vi.stubGlobal('MutationObserver', MockMutationObserver) +}) + +vi.mock('@/app/components/base/app-icon', () => ({ + default: () =>
, +})) + +vi.mock('@/app/components/base/input', () => ({ + default: ({ + value, + onChange, + onClear, + }: { + value: string + onChange: (e: { target: { value: string } }) => void + onClear?: () => void + }) => ( +
+ onChange({ target: { value: e.target.value } })} + /> + +
+ ), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ + children, + open, + }: { + children: ReactNode + open: boolean + }) => ( +
+ {children} +
+ ), + PortalToFollowElemTrigger: ({ + children, + onClick, + }: { + children: ReactNode + onClick?: () => void + }) => ( + + ), + PortalToFollowElemContent: ({ children }: { children: ReactNode }) => ( +
{children}
+ ), +})) + +const apps = [ + { + id: 'app-1', + name: 'Chat App', + mode: AppModeEnum.CHAT, + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + }, + { + id: 'app-2', + name: 'Workflow App', + mode: AppModeEnum.WORKFLOW, + icon_type: 'emoji', + icon: '⚙️', + icon_background: '#fff', + }, +] + +describe('AppPicker', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should open when the trigger is clicked', () => { + const onShowChange = vi.fn() + + render( + Trigger} + isShow={false} + onShowChange={onShowChange} + onSelect={vi.fn()} + apps={apps as never} + isLoading={false} + hasMore={false} + onLoadMore={vi.fn()} + searchText="" + onSearchChange={vi.fn()} + />, + ) + + fireEvent.click(screen.getByTestId('picker-trigger')) + + expect(onShowChange).toHaveBeenCalledWith(true) + }) + + it('should render apps, select one, and handle search changes', () => { + const onSelect = vi.fn() + const onSearchChange = vi.fn() + + render( + Trigger} + isShow + onShowChange={vi.fn()} + onSelect={onSelect} + apps={apps as never} + isLoading={false} + hasMore={false} + onLoadMore={vi.fn()} + searchText="chat" + onSearchChange={onSearchChange} + />, + ) + + fireEvent.change(screen.getByTestId('search-input'), { + target: { value: 'workflow' }, + }) + fireEvent.click(screen.getByText('Workflow App')) + fireEvent.click(screen.getByTestId('clear-input')) + + expect(onSearchChange).toHaveBeenCalledWith('workflow') + expect(onSearchChange).toHaveBeenCalledWith('') + expect(onSelect).toHaveBeenCalledWith(apps[1]) + expect(screen.getByText('chat')).toBeInTheDocument() + }) + + it('should render loading text when loading more apps', () => { + render( + Trigger} + isShow + onShowChange={vi.fn()} + onSelect={vi.fn()} + apps={apps as never} + isLoading + hasMore + onLoadMore={vi.fn()} + searchText="" + onSearchChange={vi.fn()} + />, + ) + + expect(screen.getByText('common.loading')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/hooks/__tests__/use-app-inputs-form-schema.spec.ts b/web/app/components/plugins/plugin-detail-panel/app-selector/hooks/__tests__/use-app-inputs-form-schema.spec.ts new file mode 100644 index 0000000000..d6a5b03236 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/hooks/__tests__/use-app-inputs-form-schema.spec.ts @@ -0,0 +1,141 @@ +import { renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types' +import { AppModeEnum, Resolution } from '@/types/app' +import { useAppInputsFormSchema } from '../use-app-inputs-form-schema' + +let mockAppDetailData: Record | null = null +let mockAppWorkflowData: Record | null = null + +vi.mock('@/service/use-common', () => ({ + useFileUploadConfig: () => ({ + data: { + file_size_limit: 15, + image_file_size_limit: 10, + }, + }), +})) + +vi.mock('@/service/use-apps', () => ({ + useAppDetail: () => ({ + data: mockAppDetailData, + isFetching: false, + }), +})) + +vi.mock('@/service/use-workflow', () => ({ + useAppWorkflow: () => ({ + data: mockAppWorkflowData, + isFetching: false, + }), +})) + +describe('useAppInputsFormSchema', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppDetailData = null + mockAppWorkflowData = null + }) + + it('should build basic app schemas and append image upload support', () => { + mockAppDetailData = { + id: 'app-1', + mode: AppModeEnum.COMPLETION, + model_config: { + user_input_form: [ + { + 'text-input': { + label: 'Question', + variable: 'question', + }, + }, + ], + file_upload: { + enabled: true, + image: { + enabled: true, + detail: Resolution.high, + number_limits: 2, + transfer_methods: ['local_file'], + }, + allowed_file_types: [SupportUploadFileTypes.image], + allowed_file_extensions: ['.png'], + allowed_file_upload_methods: ['local_file'], + number_limits: 2, + }, + }, + } + + const { result } = renderHook(() => useAppInputsFormSchema({ + appDetail: { + id: 'app-1', + mode: AppModeEnum.COMPLETION, + } as never, + })) + + expect(result.current.isLoading).toBe(false) + expect(result.current.inputFormSchema).toEqual(expect.arrayContaining([ + expect.objectContaining({ + variable: 'question', + type: 'text-input', + }), + expect.objectContaining({ + variable: '#image#', + type: InputVarType.singleFile, + allowed_file_extensions: ['.png'], + }), + ])) + }) + + it('should build workflow schemas from start node variables', () => { + mockAppDetailData = { + id: 'app-2', + mode: AppModeEnum.WORKFLOW, + } + mockAppWorkflowData = { + graph: { + nodes: [ + { + data: { + type: BlockEnum.Start, + variables: [ + { + label: 'Attachments', + variable: 'attachments', + type: InputVarType.multiFiles, + }, + ], + }, + }, + ], + }, + features: {}, + } + + const { result } = renderHook(() => useAppInputsFormSchema({ + appDetail: { + id: 'app-2', + mode: AppModeEnum.WORKFLOW, + } as never, + })) + + expect(result.current.inputFormSchema).toEqual([ + expect.objectContaining({ + variable: 'attachments', + type: InputVarType.multiFiles, + fileUploadConfig: expect.any(Object), + }), + ]) + }) + + it('should return an empty schema when app detail is unavailable', () => { + const { result } = renderHook(() => useAppInputsFormSchema({ + appDetail: { + id: 'missing-app', + mode: AppModeEnum.CHAT, + } as never, + })) + + expect(result.current.inputFormSchema).toEqual([]) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header/__tests__/index.spec.tsx new file mode 100644 index 0000000000..27ef4e7eb3 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/__tests__/index.spec.tsx @@ -0,0 +1,251 @@ +import type { PluginDetail } from '@/app/components/plugins/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PluginCategoryEnum, PluginSource } from '@/app/components/plugins/types' +import DetailHeader from '../index' + +const mockSetTargetVersion = vi.fn() +const mockSetVersionPickerOpen = vi.fn() +const mockHandleUpdate = vi.fn() +const mockHandleUpdatedFromMarketplace = vi.fn() +const mockHandleDelete = vi.fn() + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + userProfile: { timezone: 'UTC' }, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: () => 'en_US', + useLocale: () => 'en-US', +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: 'light' }), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllToolProviders: () => ({ + data: [{ + name: 'tool-plugin/provider-a', + type: 'builtin', + allow_delete: true, + }], + }), +})) + +vi.mock('@/utils/var', () => ({ + getMarketplaceUrl: (path: string) => `https://marketplace.example.com${path}`, +})) + +vi.mock('@/app/components/base/action-button', () => ({ + default: ({ onClick, children }: { onClick?: () => void, children: React.ReactNode }) => ( + + ), +})) + +vi.mock('@/app/components/base/button', () => ({ + default: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( + + ), +})) + +vi.mock('@/app/components/base/badge', () => ({ + default: ({ text, children }: { text?: React.ReactNode, children?: React.ReactNode }) => ( +
{text ?? children}
+ ), +})) + +vi.mock('@/app/components/base/ui/tooltip', () => ({ + Tooltip: ({ children }: { children: React.ReactNode }) =>
{children}
, + TooltipTrigger: ({ render }: { render: React.ReactNode }) => <>{render}, + TooltipContent: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +vi.mock('@/app/components/plugins/plugin-auth', () => ({ + AuthCategory: { + tool: 'tool', + }, + PluginAuth: ({ pluginPayload }: { pluginPayload: { provider: string } }) => ( +
{pluginPayload.provider}
+ ), +})) + +vi.mock('@/app/components/plugins/plugin-detail-panel/operation-dropdown', () => ({ + default: ({ detailUrl }: { detailUrl: string }) =>
{detailUrl}
, +})) + +vi.mock('@/app/components/plugins/update-plugin/plugin-version-picker', () => ({ + default: ({ onSelect, trigger }: { + onSelect: (value: { version: string, unique_identifier: string, isDowngrade?: boolean }) => void + trigger: React.ReactNode + }) => ( +
+ {trigger} + +
+ ), +})) + +vi.mock('@/app/components/base/badges/verified', () => ({ + default: () =>
, +})) + +vi.mock('@/app/components/base/deprecation-notice', () => ({ + default: () =>
, +})) + +vi.mock('@/app/components/plugins/card/base/card-icon', () => ({ + default: ({ src }: { src: string }) =>
{src}
, +})) + +vi.mock('@/app/components/plugins/card/base/description', () => ({ + default: ({ text }: { text: string }) =>
{text}
, +})) + +vi.mock('@/app/components/plugins/card/base/org-info', () => ({ + default: ({ orgName }: { orgName: string }) =>
{orgName}
, +})) + +vi.mock('@/app/components/plugins/card/base/title', () => ({ + default: ({ title }: { title: string }) =>
{title}
, +})) + +vi.mock('@/app/components/plugins/plugin-page/use-reference-setting', () => ({ + default: () => ({ + referenceSetting: { + auto_upgrade: { + upgrade_time_of_day: 0, + }, + }, + }), +})) + +vi.mock('@/app/components/plugins/reference-setting-modal/auto-update-setting/utils', () => ({ + convertUTCDaySecondsToLocalSeconds: (value: number) => value, + timeOfDayToDayjs: () => ({ + format: () => '10:00 AM', + }), +})) + +vi.mock('../components', () => ({ + HeaderModals: () =>
, + PluginSourceBadge: ({ source }: { source: string }) =>
{source}
, +})) + +vi.mock('../hooks', () => ({ + useDetailHeaderState: () => ({ + modalStates: { + isShowUpdateModal: false, + showUpdateModal: vi.fn(), + hideUpdateModal: vi.fn(), + isShowPluginInfo: false, + showPluginInfo: vi.fn(), + hidePluginInfo: vi.fn(), + isShowDeleteConfirm: false, + showDeleteConfirm: vi.fn(), + hideDeleteConfirm: vi.fn(), + deleting: false, + showDeleting: vi.fn(), + hideDeleting: vi.fn(), + }, + versionPicker: { + isShow: false, + setIsShow: mockSetVersionPickerOpen, + targetVersion: { + version: '1.0.0', + unique_identifier: 'uid-1', + }, + setTargetVersion: mockSetTargetVersion, + isDowngrade: false, + setIsDowngrade: vi.fn(), + }, + hasNewVersion: true, + isAutoUpgradeEnabled: true, + isFromGitHub: false, + isFromMarketplace: true, + }), + usePluginOperations: () => ({ + handleUpdate: mockHandleUpdate, + handleUpdatedFromMarketplace: mockHandleUpdatedFromMarketplace, + handleDelete: mockHandleDelete, + }), +})) + +const createDetail = (overrides: Partial = {}): PluginDetail => ({ + id: 'plugin-1', + created_at: '2024-01-01', + updated_at: '2024-01-02', + name: 'tool-plugin', + plugin_id: 'tool-plugin', + plugin_unique_identifier: 'tool-plugin@1.0.0', + declaration: { + author: 'acme', + category: PluginCategoryEnum.tool, + name: 'provider-a', + label: { en_US: 'Tool Plugin' }, + description: { en_US: 'Tool plugin description' }, + icon: 'icon.png', + icon_dark: 'icon-dark.png', + verified: true, + tool: { + identity: { + name: 'provider-a', + }, + }, + } as PluginDetail['declaration'], + installation_id: 'install-1', + tenant_id: 'tenant-1', + endpoints_setups: 0, + endpoints_active: 0, + version: '1.0.0', + latest_version: '2.0.0', + latest_unique_identifier: 'uid-2', + source: PluginSource.marketplace, + status: 'active', + deprecated_reason: 'Deprecated', + alternative_plugin_id: 'plugin-2', + meta: undefined, + ...overrides, +}) as PluginDetail + +describe('DetailHeader', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders the plugin summary, source badge, auth section, and modal container', () => { + render() + + expect(screen.getByTestId('title')).toHaveTextContent('Tool Plugin') + expect(screen.getByTestId('description')).toHaveTextContent('Tool plugin description') + expect(screen.getByTestId('source-badge')).toHaveTextContent('marketplace') + expect(screen.getByTestId('plugin-auth')).toHaveTextContent('tool-plugin/provider-a') + expect(screen.getByTestId('operation-dropdown')).toHaveTextContent('https://marketplace.example.com/plugins/acme/provider-a') + expect(screen.getByTestId('header-modals')).toBeInTheDocument() + }) + + it('wires version selection, latest update, and hide actions', () => { + const onHide = vi.fn() + render() + + fireEvent.click(screen.getByTestId('version-select')) + fireEvent.click(screen.getByText('plugin.detailPanel.operation.update')) + fireEvent.click(screen.getByTestId('close-button')) + + expect(mockSetTargetVersion).toHaveBeenCalledWith({ + version: '2.0.0', + unique_identifier: 'uid-2', + isDowngrade: true, + }) + expect(mockHandleUpdate).toHaveBeenCalledTimes(2) + expect(mockHandleUpdate).toHaveBeenNthCalledWith(1, true) + expect(onHide).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/components/__tests__/index.spec.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/components/__tests__/index.spec.ts new file mode 100644 index 0000000000..a932907f44 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/components/__tests__/index.spec.ts @@ -0,0 +1,9 @@ +import { describe, expect, it } from 'vitest' +import { HeaderModals, PluginSourceBadge } from '../index' + +describe('detail-header components index', () => { + it('re-exports header modal components', () => { + expect(HeaderModals).toBeDefined() + expect(PluginSourceBadge).toBeDefined() + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/index.spec.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/index.spec.ts new file mode 100644 index 0000000000..0edda1b86a --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/index.spec.ts @@ -0,0 +1,9 @@ +import { describe, expect, it } from 'vitest' +import { useDetailHeaderState, usePluginOperations } from '../index' + +describe('detail-header hooks index', () => { + it('re-exports hook entrypoints', () => { + expect(useDetailHeaderState).toBeTypeOf('function') + expect(usePluginOperations).toBeTypeOf('function') + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/components/__tests__/modal-steps.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/components/__tests__/modal-steps.spec.tsx new file mode 100644 index 0000000000..b5e2be7105 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/components/__tests__/modal-steps.spec.tsx @@ -0,0 +1,112 @@ +import type { FormRefObject } from '@/app/components/base/form/types' +import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { FormTypeEnum } from '@/app/components/base/form/types' +import { SupportedCreationMethods } from '@/app/components/plugins/types' +import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' +import { ApiKeyStep } from '../../hooks/use-common-modal-state' +import { + ConfigurationStepContent, + MultiSteps, + VerifyStepContent, +} from '../modal-steps' + +const mockBaseForm = vi.fn() +vi.mock('@/app/components/base/form/components/base', () => ({ + BaseForm: ({ + formSchemas, + onChange, + }: { + formSchemas: Array<{ name: string }> + onChange?: () => void + }) => { + mockBaseForm(formSchemas) + return ( +
+ {formSchemas.map(schema => ( + + ))} +
+ ) + }, +})) + +vi.mock('../../../log-viewer', () => ({ + default: ({ logs }: { logs: Array<{ id: string, message: string }> }) => ( +
+ {logs.map(log => {log.message})} +
+ ), +})) + +const subscriptionBuilder: TriggerSubscriptionBuilder = { + id: 'builder-1', + name: 'builder', + provider: 'provider-a', + credential_type: TriggerCredentialTypeEnum.ApiKey, + credentials: {}, + endpoint: 'https://example.com/callback', + parameters: {}, + properties: {}, + workflows_in_use: 0, +} + +const formRef = { current: null } as React.RefObject + +describe('modal-steps', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the api key multi step indicator', () => { + render() + + expect(screen.getByText('pluginTrigger.modal.steps.verify')).toBeInTheDocument() + expect(screen.getByText('pluginTrigger.modal.steps.configuration')).toBeInTheDocument() + }) + + it('should render verify step content and forward change events', () => { + const onChange = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByTestId('field-api_key')) + + expect(onChange).toHaveBeenCalled() + }) + + it('should render manual configuration content with logs', () => { + const onManualPropertiesChange = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByTestId('field-webhook_url')) + + expect(onManualPropertiesChange).toHaveBeenCalled() + expect(screen.getByTestId('log-viewer')).toHaveTextContent('log-entry') + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-common-modal-state.helpers.spec.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-common-modal-state.helpers.spec.ts new file mode 100644 index 0000000000..61482e2912 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-common-modal-state.helpers.spec.ts @@ -0,0 +1,196 @@ +import type { RefObject } from 'react' +import type { FormRefObject } from '@/app/components/base/form/types' +import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' +import { renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { SupportedCreationMethods } from '@/app/components/plugins/types' +import { + buildSubscriptionPayload, + DEFAULT_FORM_VALUES, + getConfirmButtonText, + getFirstFieldName, + getFormValues, + toSchemaWithTooltip, + useInitializeSubscriptionBuilder, + useSyncSubscriptionEndpoint, +} from '../use-common-modal-state.helpers' + +type BuilderResponse = { + subscription_builder: TriggerSubscriptionBuilder +} + +const { + mockToastError, + mockIsPrivateOrLocalAddress, +} = vi.hoisted(() => ({ + mockToastError: vi.fn(), + mockIsPrivateOrLocalAddress: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + }, +})) + +vi.mock('@/utils/urlValidation', () => ({ + isPrivateOrLocalAddress: (value: string) => mockIsPrivateOrLocalAddress(value), +})) + +describe('use-common-modal-state helpers', () => { + beforeEach(() => { + vi.clearAllMocks() + mockIsPrivateOrLocalAddress.mockReturnValue(false) + }) + + it('returns default form values when the form ref is empty', () => { + expect(getFormValues({ current: null })).toEqual(DEFAULT_FORM_VALUES) + }) + + it('returns form values from the form ref when available', () => { + expect(getFormValues({ + current: { + getFormValues: () => ({ values: { subscription_name: 'Sub' }, isCheckValidated: true }), + }, + } as unknown as React.RefObject)).toEqual({ + values: { subscription_name: 'Sub' }, + isCheckValidated: true, + }) + }) + + it('derives the first field name from values or schema fallback', () => { + expect(getFirstFieldName({ callback_url: 'https://example.com' }, [{ name: 'fallback' }])).toBe('callback_url') + expect(getFirstFieldName({}, [{ name: 'fallback' }])).toBe('fallback') + expect(getFirstFieldName({}, [])).toBe('') + }) + + it('copies schema help into tooltip fields', () => { + expect(toSchemaWithTooltip([{ name: 'field', help: 'Help text' }])).toEqual([ + { + name: 'field', + help: 'Help text', + tooltip: 'Help text', + }, + ]) + }) + + it('builds subscription payloads for automatic and manual creation', () => { + expect(buildSubscriptionPayload({ + provider: 'provider-a', + subscriptionBuilderId: 'builder-a', + createType: SupportedCreationMethods.APIKEY, + subscriptionFormValues: { values: { subscription_name: 'My Sub' }, isCheckValidated: true }, + autoCommonParametersSchemaLength: 1, + autoCommonParametersFormValues: { values: { api_key: '123' }, isCheckValidated: true }, + manualPropertiesSchemaLength: 0, + manualPropertiesFormValues: undefined, + })).toEqual({ + provider: 'provider-a', + subscriptionBuilderId: 'builder-a', + name: 'My Sub', + parameters: { api_key: '123' }, + }) + + expect(buildSubscriptionPayload({ + provider: 'provider-a', + subscriptionBuilderId: 'builder-a', + createType: SupportedCreationMethods.MANUAL, + subscriptionFormValues: { values: { subscription_name: 'Manual Sub' }, isCheckValidated: true }, + autoCommonParametersSchemaLength: 0, + autoCommonParametersFormValues: undefined, + manualPropertiesSchemaLength: 1, + manualPropertiesFormValues: { values: { custom: 'value' }, isCheckValidated: true }, + })).toEqual({ + provider: 'provider-a', + subscriptionBuilderId: 'builder-a', + name: 'Manual Sub', + }) + }) + + it('returns null when required validation is missing', () => { + expect(buildSubscriptionPayload({ + provider: 'provider-a', + subscriptionBuilderId: 'builder-a', + createType: SupportedCreationMethods.APIKEY, + subscriptionFormValues: { values: {}, isCheckValidated: false }, + autoCommonParametersSchemaLength: 1, + autoCommonParametersFormValues: { values: {}, isCheckValidated: true }, + manualPropertiesSchemaLength: 0, + manualPropertiesFormValues: undefined, + })).toBeNull() + }) + + it('builds confirm button text for verify and create states', () => { + const t = (key: string, options?: Record) => `${options?.ns}.${key}` + + expect(getConfirmButtonText({ + isVerifyStep: true, + isVerifyingCredentials: false, + isBuilding: false, + t, + })).toBe('pluginTrigger.modal.common.verify') + + expect(getConfirmButtonText({ + isVerifyStep: false, + isVerifyingCredentials: false, + isBuilding: true, + t, + })).toBe('pluginTrigger.modal.common.creating') + }) + + it('initializes the subscription builder once when provider is available', async () => { + const createBuilder = vi.fn(async () => ({ + subscription_builder: { id: 'builder-1' }, + })) as unknown as (params: { + provider: string + credential_type: string + }) => Promise + const setSubscriptionBuilder = vi.fn() + + renderHook(() => useInitializeSubscriptionBuilder({ + createBuilder, + credentialType: 'oauth', + provider: 'provider-a', + subscriptionBuilder: undefined, + setSubscriptionBuilder, + t: (key: string, options?: Record) => `${options?.ns}.${key}`, + })) + + await waitFor(() => { + expect(createBuilder).toHaveBeenCalledWith({ + provider: 'provider-a', + credential_type: 'oauth', + }) + expect(setSubscriptionBuilder).toHaveBeenCalledWith({ id: 'builder-1' }) + }) + }) + + it('syncs callback endpoint and warnings into the subscription form', async () => { + mockIsPrivateOrLocalAddress.mockReturnValue(true) + const setFieldValue = vi.fn() + const setFields = vi.fn() + const subscriptionFormRef = { + current: { + getForm: () => ({ + setFieldValue, + }), + setFields, + }, + } as unknown as RefObject + + renderHook(() => useSyncSubscriptionEndpoint({ + endpoint: 'http://127.0.0.1/callback', + isConfigurationStep: true, + subscriptionFormRef, + t: (key: string, options?: Record) => `${options?.ns}.${key}`, + })) + + await waitFor(() => { + expect(setFieldValue).toHaveBeenCalledWith('callback_url', 'http://127.0.0.1/callback') + expect(setFields).toHaveBeenCalledWith([{ + name: 'callback_url', + warnings: ['pluginTrigger.modal.form.callbackUrl.privateAddressWarning'], + }]) + }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-common-modal-state.spec.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-common-modal-state.spec.ts new file mode 100644 index 0000000000..399d3ba60c --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-common-modal-state.spec.ts @@ -0,0 +1,253 @@ +import type { FormRefObject } from '@/app/components/base/form/types' +import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { SupportedCreationMethods } from '@/app/components/plugins/types' +import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' +import { ApiKeyStep, useCommonModalState } from '../use-common-modal-state' + +type MockPluginDetail = { + plugin_id: string + provider: string + name: string + declaration: { + trigger: { + subscription_schema: Array<{ name: string, type: string, description?: string }> + subscription_constructor: { + credentials_schema: Array<{ name: string, type: string, help?: string }> + parameters: Array<{ name: string, type: string }> + } + } + } +} + +const createMockBuilder = (overrides: Partial = {}): TriggerSubscriptionBuilder => ({ + id: 'builder-1', + name: 'builder', + provider: 'provider-a', + credential_type: TriggerCredentialTypeEnum.ApiKey, + credentials: {}, + endpoint: 'https://example.com/callback', + parameters: {}, + properties: {}, + workflows_in_use: 0, + ...overrides, +}) + +const mockDetail: MockPluginDetail = { + plugin_id: 'plugin-id', + provider: 'provider-a', + name: 'Plugin A', + declaration: { + trigger: { + subscription_schema: [{ name: 'webhook_url', type: 'string', description: 'Webhook URL' }], + subscription_constructor: { + credentials_schema: [{ name: 'api_key', type: 'string', help: 'API key help' }], + parameters: [{ name: 'repo_name', type: 'string' }], + }, + }, + }, +} + +const mockUsePluginStore = vi.fn(() => mockDetail) +vi.mock('../../../../store', () => ({ + usePluginStore: () => mockUsePluginStore(), +})) + +const mockRefetch = vi.fn() +vi.mock('../../../use-subscription-list', () => ({ + useSubscriptionList: () => ({ refetch: mockRefetch }), +})) + +const mockVerifyCredentials = vi.fn() +const mockCreateBuilder = vi.fn() +const mockBuildSubscription = vi.fn() +const mockUpdateBuilder = vi.fn() +let mockIsVerifyingCredentials = false +let mockIsBuilding = false + +vi.mock('@/service/use-triggers', () => ({ + useVerifyAndUpdateTriggerSubscriptionBuilder: () => ({ + mutate: mockVerifyCredentials, + get isPending() { return mockIsVerifyingCredentials }, + }), + useCreateTriggerSubscriptionBuilder: () => ({ + mutateAsync: mockCreateBuilder, + }), + useBuildTriggerSubscription: () => ({ + mutate: mockBuildSubscription, + get isPending() { return mockIsBuilding }, + }), + useUpdateTriggerSubscriptionBuilder: () => ({ + mutate: mockUpdateBuilder, + }), + useTriggerSubscriptionBuilderLogs: () => ({ + data: { logs: [] }, + }), +})) + +const mockToastNotify = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + }, +})) + +const mockParsePluginErrorMessage = vi.fn().mockResolvedValue(null) +vi.mock('@/utils/error-parser', () => ({ + parsePluginErrorMessage: (...args: unknown[]) => mockParsePluginErrorMessage(...args), +})) + +vi.mock('@/utils/urlValidation', () => ({ + isPrivateOrLocalAddress: vi.fn().mockReturnValue(false), +})) + +const createFormRef = ({ + values = {}, + isCheckValidated = true, +}: { + values?: Record + isCheckValidated?: boolean +} = {}): FormRefObject => ({ + getFormValues: vi.fn().mockReturnValue({ values, isCheckValidated }), + setFields: vi.fn(), + getForm: vi.fn().mockReturnValue({ + setFieldValue: vi.fn(), + }), +} as unknown as FormRefObject) + +describe('useCommonModalState', () => { + beforeEach(() => { + vi.clearAllMocks() + mockIsVerifyingCredentials = false + mockIsBuilding = false + mockCreateBuilder.mockResolvedValue({ + subscription_builder: createMockBuilder(), + }) + }) + + it('should initialize api key builders and expose verify step state', async () => { + const { result } = renderHook(() => useCommonModalState({ + createType: SupportedCreationMethods.APIKEY, + onClose: vi.fn(), + })) + + await waitFor(() => { + expect(result.current.subscriptionBuilder?.id).toBe('builder-1') + }) + + expect(mockCreateBuilder).toHaveBeenCalledWith({ + provider: 'provider-a', + credential_type: TriggerCredentialTypeEnum.ApiKey, + }) + expect(result.current.currentStep).toBe(ApiKeyStep.Verify) + expect(result.current.apiKeyCredentialsSchema[0]).toMatchObject({ + name: 'api_key', + tooltip: 'API key help', + }) + }) + + it('should verify credentials and advance to configuration step', async () => { + mockVerifyCredentials.mockImplementation((_payload, options) => { + options?.onSuccess?.() + }) + + const builder = createMockBuilder() + const { result } = renderHook(() => useCommonModalState({ + createType: SupportedCreationMethods.APIKEY, + builder, + onClose: vi.fn(), + })) + + const credentialsFormRef = result.current.formRefs.apiKeyCredentialsFormRef as { current: FormRefObject | null } + credentialsFormRef.current = createFormRef({ + values: { api_key: 'secret' }, + }) + + act(() => { + result.current.handleVerify() + }) + + expect(mockVerifyCredentials).toHaveBeenCalledWith({ + provider: 'provider-a', + subscriptionBuilderId: builder.id, + credentials: { api_key: 'secret' }, + }, expect.objectContaining({ + onSuccess: expect.any(Function), + onError: expect.any(Function), + })) + expect(result.current.currentStep).toBe(ApiKeyStep.Configuration) + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + }) + + it('should build subscriptions with validated automatic parameters', () => { + const onClose = vi.fn() + const builder = createMockBuilder() + const { result } = renderHook(() => useCommonModalState({ + createType: SupportedCreationMethods.APIKEY, + builder, + onClose, + })) + + const subscriptionFormRef = result.current.formRefs.subscriptionFormRef as { current: FormRefObject | null } + const autoParamsFormRef = result.current.formRefs.autoCommonParametersFormRef as { current: FormRefObject | null } + + subscriptionFormRef.current = createFormRef({ + values: { subscription_name: 'Subscription A' }, + }) + autoParamsFormRef.current = createFormRef({ + values: { repo_name: 'repo-a' }, + }) + + act(() => { + result.current.handleCreate() + }) + + expect(mockBuildSubscription).toHaveBeenCalledWith({ + provider: 'provider-a', + subscriptionBuilderId: builder.id, + name: 'Subscription A', + parameters: { repo_name: 'repo-a' }, + }, expect.objectContaining({ + onSuccess: expect.any(Function), + onError: expect.any(Function), + })) + }) + + it('should debounce manual property updates', async () => { + vi.useFakeTimers() + + const builder = createMockBuilder({ + credential_type: TriggerCredentialTypeEnum.Unauthorized, + }) + const { result } = renderHook(() => useCommonModalState({ + createType: SupportedCreationMethods.MANUAL, + builder, + onClose: vi.fn(), + })) + + const manualFormRef = result.current.formRefs.manualPropertiesFormRef as { current: FormRefObject | null } + manualFormRef.current = createFormRef({ + values: { webhook_url: 'https://hook.example.com' }, + isCheckValidated: true, + }) + + act(() => { + result.current.handleManualPropertiesChange() + vi.advanceTimersByTime(500) + }) + + expect(mockUpdateBuilder).toHaveBeenCalledWith({ + provider: 'provider-a', + subscriptionBuilderId: builder.id, + properties: { webhook_url: 'https://hook.example.com' }, + }, expect.objectContaining({ + onError: expect.any(Function), + })) + + vi.useRealTimers() + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.helpers.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.helpers.ts new file mode 100644 index 0000000000..8df864c4fa --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.helpers.ts @@ -0,0 +1,180 @@ +'use client' +import type { Dispatch, SetStateAction } from 'react' +import type { FormRefObject } from '@/app/components/base/form/types' +import type { TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' +import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers' +import { useEffect, useRef } from 'react' +import { toast } from '@/app/components/base/ui/toast' +import { SupportedCreationMethods } from '@/app/components/plugins/types' +import { isPrivateOrLocalAddress } from '@/utils/urlValidation' + +type FormValuesResult = { + values: Record + isCheckValidated: boolean +} + +type InitializeBuilderParams = { + createBuilder: (params: { + provider: string + credential_type: string + }) => Promise<{ subscription_builder: TriggerSubscriptionBuilder }> + credentialType: string + provider?: string + subscriptionBuilder?: TriggerSubscriptionBuilder + setSubscriptionBuilder: Dispatch> + t: (key: string, options?: Record) => string +} + +type SyncEndpointParams = { + endpoint?: string + isConfigurationStep: boolean + subscriptionFormRef: React.RefObject + t: (key: string, options?: Record) => string +} + +type BuildPayloadParams = { + provider: string + subscriptionBuilderId: string + createType: SupportedCreationMethods + subscriptionFormValues?: FormValuesResult + autoCommonParametersSchemaLength: number + autoCommonParametersFormValues?: FormValuesResult + manualPropertiesSchemaLength: number + manualPropertiesFormValues?: FormValuesResult +} + +export const DEFAULT_FORM_VALUES: FormValuesResult = { values: {}, isCheckValidated: false } + +export const getFormValues = (formRef: React.RefObject) => { + return formRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES +} + +export const getFirstFieldName = ( + values: Record, + fallbackSchema: Array<{ name: string }>, +) => { + return Object.keys(values)[0] || fallbackSchema[0]?.name || '' +} + +export const toSchemaWithTooltip = (schemas: T[] = []) => { + return schemas.map(schema => ({ + ...schema, + tooltip: schema.help, + })) +} + +export const buildSubscriptionPayload = ({ + provider, + subscriptionBuilderId, + createType, + subscriptionFormValues, + autoCommonParametersSchemaLength, + autoCommonParametersFormValues, + manualPropertiesSchemaLength, + manualPropertiesFormValues, +}: BuildPayloadParams): BuildTriggerSubscriptionPayload | null => { + if (!subscriptionFormValues?.isCheckValidated) + return null + + const subscriptionNameValue = subscriptionFormValues.values.subscription_name as string + + const params: BuildTriggerSubscriptionPayload = { + provider, + subscriptionBuilderId, + name: subscriptionNameValue, + } + + if (createType !== SupportedCreationMethods.MANUAL) { + if (!autoCommonParametersSchemaLength) + return params + + if (!autoCommonParametersFormValues?.isCheckValidated) + return null + + params.parameters = autoCommonParametersFormValues.values + return params + } + + if (manualPropertiesSchemaLength && !manualPropertiesFormValues?.isCheckValidated) + return null + + return params +} + +export const getConfirmButtonText = ({ + isVerifyStep, + isVerifyingCredentials, + isBuilding, + t, +}: { + isVerifyStep: boolean + isVerifyingCredentials: boolean + isBuilding: boolean + t: (key: string, options?: Record) => string +}) => { + if (isVerifyStep) { + return isVerifyingCredentials + ? t('modal.common.verifying', { ns: 'pluginTrigger' }) + : t('modal.common.verify', { ns: 'pluginTrigger' }) + } + + return isBuilding + ? t('modal.common.creating', { ns: 'pluginTrigger' }) + : t('modal.common.create', { ns: 'pluginTrigger' }) +} + +export const useInitializeSubscriptionBuilder = ({ + createBuilder, + credentialType, + provider, + subscriptionBuilder, + setSubscriptionBuilder, + t, +}: InitializeBuilderParams) => { + const isInitializedRef = useRef(false) + + useEffect(() => { + const initializeBuilder = async () => { + isInitializedRef.current = true + try { + const response = await createBuilder({ + provider: provider || '', + credential_type: credentialType, + }) + setSubscriptionBuilder(response.subscription_builder) + } + catch (error) { + console.error('createBuilder error:', error) + toast.error(t('modal.errors.createFailed', { ns: 'pluginTrigger' })) + } + } + + if (!isInitializedRef.current && !subscriptionBuilder && provider) + initializeBuilder() + }, [subscriptionBuilder, provider, credentialType, createBuilder, setSubscriptionBuilder, t]) +} + +export const useSyncSubscriptionEndpoint = ({ + endpoint, + isConfigurationStep, + subscriptionFormRef, + t, +}: SyncEndpointParams) => { + useEffect(() => { + if (!endpoint || !subscriptionFormRef.current || !isConfigurationStep) + return + + const form = subscriptionFormRef.current.getForm() + if (form) + form.setFieldValue('callback_url', endpoint) + + const warnings = isPrivateOrLocalAddress(endpoint) + ? [t('modal.form.callbackUrl.privateAddressWarning', { ns: 'pluginTrigger' })] + : [] + + subscriptionFormRef.current.setFields([{ + name: 'callback_url', + warnings, + }]) + }, [endpoint, isConfigurationStep, subscriptionFormRef, t]) +} diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts index 339f782b45..e55f9525fe 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts @@ -3,7 +3,6 @@ import type { SimpleDetail } from '../../../store' import type { SchemaItem } from '../components/modal-steps' import type { FormRefObject } from '@/app/components/base/form/types' import type { TriggerLogEntity, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' -import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers' import { debounce } from 'es-toolkit/compat' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -18,9 +17,17 @@ import { useVerifyAndUpdateTriggerSubscriptionBuilder, } from '@/service/use-triggers' import { parsePluginErrorMessage } from '@/utils/error-parser' -import { isPrivateOrLocalAddress } from '@/utils/urlValidation' import { usePluginStore } from '../../../store' import { useSubscriptionList } from '../../use-subscription-list' +import { + buildSubscriptionPayload, + getConfirmButtonText, + getFirstFieldName, + getFormValues, + toSchemaWithTooltip, + useInitializeSubscriptionBuilder, + useSyncSubscriptionEndpoint, +} from './use-common-modal-state.helpers' // ============================================================================ // Types @@ -85,8 +92,6 @@ type UseCommonModalStateReturn = { handleApiKeyCredentialsChange: () => void } -const DEFAULT_FORM_VALUES = { values: {}, isCheckValidated: false } - // ============================================================================ // Hook Implementation // ============================================================================ @@ -105,7 +110,6 @@ export const useCommonModalState = ({ createType === SupportedCreationMethods.APIKEY ? ApiKeyStep.Verify : ApiKeyStep.Configuration, ) const [subscriptionBuilder, setSubscriptionBuilder] = useState(builder) - const isInitializedRef = useRef(false) // Form refs const manualPropertiesFormRef = useRef(null) @@ -123,12 +127,9 @@ export const useCommonModalState = ({ const manualPropertiesSchema = detail?.declaration?.trigger?.subscription_schema || [] const autoCommonParametersSchema = detail?.declaration.trigger?.subscription_constructor?.parameters || [] - const apiKeyCredentialsSchema = useMemo(() => { + const apiKeyCredentialsSchema = useMemo(() => { const rawSchema = detail?.declaration?.trigger?.subscription_constructor?.credentials_schema || [] - return rawSchema.map(schema => ({ - ...schema, - tooltip: schema.help, - })) + return toSchemaWithTooltip(rawSchema) as SchemaItem[] }, [detail?.declaration?.trigger?.subscription_constructor?.credentials_schema]) // Log data for manual mode @@ -162,25 +163,14 @@ export const useCommonModalState = ({ [updateBuilder, t], ) - // Initialize builder - useEffect(() => { - const initializeBuilder = async () => { - isInitializedRef.current = true - try { - const response = await createBuilder({ - provider: detail?.provider || '', - credential_type: CREDENTIAL_TYPE_MAP[createType], - }) - setSubscriptionBuilder(response.subscription_builder) - } - catch (error) { - console.error('createBuilder error:', error) - toast.error(t('modal.errors.createFailed', { ns: 'pluginTrigger' })) - } - } - if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider) - initializeBuilder() - }, [subscriptionBuilder, detail?.provider, createType, createBuilder, t]) + useInitializeSubscriptionBuilder({ + createBuilder, + credentialType: CREDENTIAL_TYPE_MAP[createType], + provider: detail?.provider, + subscriptionBuilder, + setSubscriptionBuilder, + t, + }) // Cleanup debounced function useEffect(() => { @@ -189,24 +179,12 @@ export const useCommonModalState = ({ } }, [debouncedUpdate]) - // Update endpoint in form when endpoint changes - useEffect(() => { - if (!subscriptionBuilder?.endpoint || !subscriptionFormRef.current || currentStep !== ApiKeyStep.Configuration) - return - - const form = subscriptionFormRef.current.getForm() - if (form) - form.setFieldValue('callback_url', subscriptionBuilder.endpoint) - - const warnings = isPrivateOrLocalAddress(subscriptionBuilder.endpoint) - ? [t('modal.form.callbackUrl.privateAddressWarning', { ns: 'pluginTrigger' })] - : [] - - subscriptionFormRef.current?.setFields([{ - name: 'callback_url', - warnings, - }]) - }, [subscriptionBuilder?.endpoint, currentStep, t]) + useSyncSubscriptionEndpoint({ + endpoint: subscriptionBuilder?.endpoint, + isConfigurationStep: currentStep === ApiKeyStep.Configuration, + subscriptionFormRef, + t, + }) // Handle manual properties change const handleManualPropertiesChange = useCallback(() => { @@ -237,7 +215,7 @@ export const useCommonModalState = ({ return } - const apiKeyCredentialsFormValues = apiKeyCredentialsFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES + const apiKeyCredentialsFormValues = getFormValues(apiKeyCredentialsFormRef) const credentials = apiKeyCredentialsFormValues.values if (!Object.keys(credentials).length) { @@ -245,8 +223,10 @@ export const useCommonModalState = ({ return } + const credentialFieldName = getFirstFieldName(credentials, apiKeyCredentialsSchema) + apiKeyCredentialsFormRef.current?.setFields([{ - name: Object.keys(credentials)[0], + name: credentialFieldName, errors: [], }]) @@ -264,13 +244,13 @@ export const useCommonModalState = ({ onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' }) apiKeyCredentialsFormRef.current?.setFields([{ - name: Object.keys(credentials)[0], + name: credentialFieldName, errors: [errorMessage], }]) }, }, ) - }, [detail?.provider, subscriptionBuilder?.id, verifyCredentials, t]) + }, [apiKeyCredentialsSchema, detail?.provider, subscriptionBuilder?.id, verifyCredentials, t]) // Handle create const handleCreate = useCallback(() => { @@ -279,31 +259,19 @@ export const useCommonModalState = ({ return } - const subscriptionFormValues = subscriptionFormRef.current?.getFormValues({}) - if (!subscriptionFormValues?.isCheckValidated) - return - - const subscriptionNameValue = subscriptionFormValues?.values?.subscription_name as string - - const params: BuildTriggerSubscriptionPayload = { + const params = buildSubscriptionPayload({ provider: detail?.provider || '', subscriptionBuilderId: subscriptionBuilder.id, - name: subscriptionNameValue, - } + createType, + subscriptionFormValues: getFormValues(subscriptionFormRef), + autoCommonParametersSchemaLength: autoCommonParametersSchema.length, + autoCommonParametersFormValues: getFormValues(autoCommonParametersFormRef), + manualPropertiesSchemaLength: manualPropertiesSchema.length, + manualPropertiesFormValues: getFormValues(manualPropertiesFormRef), + }) - if (createType !== SupportedCreationMethods.MANUAL) { - if (autoCommonParametersSchema.length > 0) { - const autoCommonParametersFormValues = autoCommonParametersFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES - if (!autoCommonParametersFormValues?.isCheckValidated) - return - params.parameters = autoCommonParametersFormValues.values - } - } - else if (manualPropertiesSchema.length > 0) { - const manualFormValues = manualPropertiesFormRef.current?.getFormValues({}) || DEFAULT_FORM_VALUES - if (!manualFormValues?.isCheckValidated) - return - } + if (!params) + return buildSubscription( params, @@ -341,14 +309,12 @@ export const useCommonModalState = ({ // Confirm button text const confirmButtonText = useMemo(() => { - if (currentStep === ApiKeyStep.Verify) { - return isVerifyingCredentials - ? t('modal.common.verifying', { ns: 'pluginTrigger' }) - : t('modal.common.verify', { ns: 'pluginTrigger' }) - } - return isBuilding - ? t('modal.common.creating', { ns: 'pluginTrigger' }) - : t('modal.common.create', { ns: 'pluginTrigger' }) + return getConfirmButtonText({ + isVerifyStep: currentStep === ApiKeyStep.Verify, + isVerifyingCredentials, + isBuilding, + t, + }) }, [currentStep, isVerifyingCredentials, isBuilding, t]) return { diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/index.spec.ts b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/index.spec.ts new file mode 100644 index 0000000000..3ff43c4fb6 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/index.spec.ts @@ -0,0 +1,22 @@ +import { describe, expect, it } from 'vitest' +import { + SchemaModal, + ToolAuthorizationSection, + ToolBaseForm, + ToolCredentialsForm, + ToolItem, + ToolSettingsPanel, + ToolTrigger, +} from '../index' + +describe('tool-selector components index', () => { + it('re-exports the tool selector components', () => { + expect(SchemaModal).toBeDefined() + expect(ToolAuthorizationSection).toBeDefined() + expect(ToolBaseForm).toBeDefined() + expect(ToolCredentialsForm).toBeDefined() + expect(ToolItem).toBeDefined() + expect(ToolSettingsPanel).toBeDefined() + expect(ToolTrigger).toBeDefined() + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/reasoning-config-form.helpers.spec.ts b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/reasoning-config-form.helpers.spec.ts new file mode 100644 index 0000000000..24d7fd036d --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/reasoning-config-form.helpers.spec.ts @@ -0,0 +1,181 @@ +import { describe, expect, it } from 'vitest' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' +import { VarType } from '@/app/components/workflow/types' +import { + createEmptyAppValue, + createFilterVar, + createPickerProps, + createReasoningFormContext, + getFieldFlags, + getFieldTitle, + getVarKindType, + getVisibleSelectOptions, + mergeReasoningValue, + resolveTargetVarType, + updateInputAutoState, + updateReasoningValue, + updateVariableSelectorValue, + updateVariableTypeValue, +} from '../reasoning-config-form.helpers' + +describe('reasoning-config-form helpers', () => { + it('maps schema types to variable-kind types and target variable types', () => { + expect(getVarKindType(FormTypeEnum.files)).toBe(VarKindType.variable) + expect(getVarKindType(FormTypeEnum.textNumber)).toBe(VarKindType.constant) + expect(getVarKindType(FormTypeEnum.textInput)).toBe(VarKindType.mixed) + expect(getVarKindType(FormTypeEnum.dynamicSelect)).toBeUndefined() + + expect(resolveTargetVarType(FormTypeEnum.textInput)).toBe(VarType.string) + expect(resolveTargetVarType(FormTypeEnum.textNumber)).toBe(VarType.number) + expect(resolveTargetVarType(FormTypeEnum.files)).toBe(VarType.arrayFile) + expect(resolveTargetVarType(FormTypeEnum.file)).toBe(VarType.file) + expect(resolveTargetVarType(FormTypeEnum.checkbox)).toBe(VarType.boolean) + expect(resolveTargetVarType(FormTypeEnum.object)).toBe(VarType.object) + expect(resolveTargetVarType(FormTypeEnum.array)).toBe(VarType.arrayObject) + }) + + it('creates variable filters for supported field types', () => { + const numberFilter = createFilterVar(FormTypeEnum.textNumber) + const stringFilter = createFilterVar(FormTypeEnum.textInput) + const fileFilter = createFilterVar(FormTypeEnum.files) + + expect(numberFilter?.({ type: VarType.number } as never)).toBe(true) + expect(numberFilter?.({ type: VarType.string } as never)).toBe(false) + expect(stringFilter?.({ type: VarType.secret } as never)).toBe(true) + expect(fileFilter?.({ type: VarType.arrayFile } as never)).toBe(true) + }) + + it('filters select options based on show_on conditions', () => { + const options = [ + { + value: 'one', + label: { en_US: 'One', zh_Hans: 'One' }, + show_on: [], + }, + { + value: 'two', + label: { en_US: 'Two', zh_Hans: 'Two' }, + show_on: [{ variable: 'mode', value: 'advanced' }], + }, + ] + + expect(getVisibleSelectOptions(options as never, { + mode: { value: { value: 'advanced' } }, + }, 'en_US')).toEqual([ + { value: 'one', name: 'One' }, + { value: 'two', name: 'Two' }, + ]) + + expect(getVisibleSelectOptions(options as never, { + mode: { value: { value: 'basic' } }, + }, 'en_US')).toEqual([ + { value: 'one', name: 'One' }, + ]) + }) + + it('updates reasoning values for auto, constant, variable, and merged states', () => { + const value = { + prompt: { + value: { + type: VarKindType.constant, + value: 'hello', + }, + auto: 0 as const, + }, + } + + expect(updateInputAutoState(value, 'prompt', true, FormTypeEnum.textInput)).toEqual({ + prompt: { + value: null, + auto: 1, + }, + }) + + expect(updateVariableTypeValue(value, 'prompt', VarKindType.variable, '')).toEqual({ + prompt: { + value: { + type: VarKindType.variable, + value: '', + }, + auto: 0, + }, + }) + + expect(updateReasoningValue(value, 'prompt', FormTypeEnum.textInput, 'updated')).toEqual({ + prompt: { + value: { + type: VarKindType.mixed, + value: 'updated', + }, + auto: 0, + }, + }) + + expect(mergeReasoningValue(value, 'prompt', { extra: true })).toEqual({ + prompt: { + value: { + type: VarKindType.constant, + value: 'hello', + extra: true, + }, + auto: 0, + }, + }) + + expect(updateVariableSelectorValue(value, 'prompt', ['node', 'field'])).toEqual({ + prompt: { + value: { + type: VarKindType.variable, + value: ['node', 'field'], + }, + auto: 0, + }, + }) + }) + + it('derives field flags and picker props from schema types', () => { + expect(getFieldFlags(FormTypeEnum.object, { type: VarKindType.constant })).toEqual(expect.objectContaining({ + isObject: true, + isShowJSONEditor: true, + showTypeSwitch: true, + isConstant: true, + })) + + expect(createPickerProps({ + type: FormTypeEnum.select, + value: {}, + language: 'en_US', + schema: { + options: [ + { + value: 'one', + label: { en_US: 'One', zh_Hans: 'One' }, + show_on: [], + }, + ], + } as never, + })).toEqual(expect.objectContaining({ + targetVarType: VarType.string, + selectItems: [{ value: 'one', name: 'One' }], + })) + }) + + it('provides label helpers and empty defaults', () => { + expect(getFieldTitle({ en_US: 'Prompt', zh_Hans: 'Prompt' }, 'en_US')).toBe('Prompt') + expect(createEmptyAppValue()).toEqual({ + app_id: '', + inputs: {}, + files: [], + }) + expect(createReasoningFormContext({ + availableNodes: [{ id: 'node-1' }] as never, + nodeId: 'node-current', + nodeOutputVars: [{ nodeId: 'node-1' }] as never, + })).toEqual({ + availableNodes: [{ id: 'node-1' }], + nodeId: 'node-current', + nodeOutputVars: [{ nodeId: 'node-1' }], + }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/reasoning-config-form.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/reasoning-config-form.spec.tsx new file mode 100644 index 0000000000..f64d396d07 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/reasoning-config-form.spec.tsx @@ -0,0 +1,340 @@ +import type { ToolFormSchema } from '@/app/components/tools/utils/to-form-schema' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { Type } from '@/app/components/workflow/nodes/llm/types' +import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' +import ReasoningConfigForm from '../reasoning-config-form' + +vi.mock('@/app/components/base/input', () => ({ + default: ({ value, onChange }: { value?: string, onChange: (e: { target: { value: string } }) => void }) => ( + onChange({ target: { value: e.target.value } })} /> + ), +})) + +vi.mock('@/app/components/base/select', () => ({ + SimpleSelect: ({ + items, + onSelect, + }: { + items: Array<{ value: string, name: string }> + onSelect: (item: { value: string }) => void + }) => ( +
+ {items.map(item => ( + + ))} +
+ ), +})) + +vi.mock('@/app/components/base/switch', () => ({ + default: ({ value, onChange }: { value: boolean, onChange: (value: boolean) => void }) => ( + + ), +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ children }: { children?: React.ReactNode }) => <>{children}, +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useLanguage: () => 'en_US', +})) + +vi.mock('@/app/components/plugins/plugin-detail-panel/app-selector', () => ({ + default: ({ onSelect }: { onSelect: (value: Record) => void }) => ( + + ), +})) + +vi.mock('@/app/components/plugins/plugin-detail-panel/model-selector', () => ({ + default: ({ setModel }: { setModel: (value: Record) => void }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({ + default: ({ onChange }: { onChange: (value: string) => void }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/form-input-boolean', () => ({ + default: ({ onChange }: { onChange: (value: boolean) => void }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/form-input-type-switch', () => ({ + default: ({ onChange }: { onChange: (value: VarKindType) => void }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({ + default: ({ onChange }: { onChange: (value: string) => void }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/tool/components/mixed-variable-text-input', () => ({ + default: ({ onChange }: { onChange: (value: string) => void }) => ( + + ), +})) + +vi.mock('../schema-modal', () => ({ + default: ({ isShow, rootName, onClose }: { isShow: boolean, rootName: string, onClose: () => void }) => ( + isShow + ? ( +
+ {rootName} + +
+ ) + : null + ), +})) + +const createSchema = (overrides: Partial = {}): ToolFormSchema => ({ + variable: 'field', + type: FormTypeEnum.textInput, + default: '', + required: false, + label: { en_US: 'Field', zh_Hans: '字段' }, + tooltip: { en_US: 'Tooltip', zh_Hans: '提示' }, + scope: 'all', + url: '', + input_schema: {}, + placeholder: { en_US: 'Placeholder', zh_Hans: '占位符' }, + options: [], + ...overrides, +} as ToolFormSchema) + +describe('ReasoningConfigForm', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should toggle automatic values for text fields', () => { + const onChange = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByTestId('auto-switch')) + + expect(onChange).toHaveBeenCalledWith({ + field: { + auto: 1, + value: null, + }, + }) + }) + + it('should update mixed text and variable types', () => { + const onChange = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByTestId('mixed-input')) + fireEvent.click(screen.getByTestId('type-switch')) + + expect(onChange).toHaveBeenNthCalledWith(1, expect.objectContaining({ + field: { + auto: 0, + value: { type: VarKindType.mixed, value: 'updated-text' }, + }, + })) + expect(onChange).toHaveBeenNthCalledWith(2, expect.objectContaining({ + count: { + auto: 0, + value: { type: VarKindType.variable, value: '' }, + }, + })) + }) + + it('should open schema modal for object fields and support app selection', () => { + const onChange = vi.fn() + + const { container } = render( + , + ) + + fireEvent.click(container.querySelector('div.ml-0\\.5.cursor-pointer')!) + expect(screen.getByTestId('schema-modal')).toHaveTextContent('Config') + fireEvent.click(screen.getByTestId('close-schema')) + + fireEvent.click(screen.getByTestId('app-selector')) + + expect(onChange).toHaveBeenCalledWith(expect.objectContaining({ + app: { + auto: 0, + value: { + type: undefined, + value: { app_id: 'app-1', inputs: { topic: 'hello' } }, + }, + }, + })) + }) + + it('should merge model selector values into the current field value', () => { + const onChange = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByTestId('model-selector')) + + expect(onChange).toHaveBeenCalledWith({ + model: { + auto: 0, + value: { + provider: 'openai', + model: 'gpt-4.1', + }, + }, + }) + }) + + it('should update file fields from the variable selector', () => { + const onChange = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByTestId('var-picker')) + + expect(onChange).toHaveBeenCalledWith({ + files: { + auto: 0, + value: { + type: VarKindType.variable, + value: ['node', 'field'], + }, + }, + }) + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/schema-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/schema-modal.spec.tsx new file mode 100644 index 0000000000..86158ab950 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/schema-modal.spec.tsx @@ -0,0 +1,61 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import SchemaModal from '../schema-modal' + +vi.mock('@/app/components/base/modal', () => ({ + default: ({ + children, + isShow, + }: { + children: React.ReactNode + isShow: boolean + }) => isShow ?
{children}
: null, +})) + +vi.mock('@/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor', () => ({ + default: ({ rootName }: { rootName: string }) =>
{rootName}
, +})) + +vi.mock('@/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context', () => ({ + MittProvider: ({ children }: { children: React.ReactNode }) => <>{children}, + VisualEditorContextProvider: ({ children }: { children: React.ReactNode }) => <>{children}, +})) + +describe('SchemaModal', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('does not render content when hidden', () => { + render( + , + ) + + expect(screen.queryByTestId('modal')).not.toBeInTheDocument() + }) + + it('renders the schema title and closes when the close control is clicked', () => { + const onClose = vi.fn() + render( + , + ) + + expect(screen.getByText('workflow.nodes.agent.parameterSchema')).toBeInTheDocument() + expect(screen.getByTestId('visual-editor')).toHaveTextContent('response') + + const closeButton = document.body.querySelector('div.absolute.right-5.top-5') + fireEvent.click(closeButton!) + + expect(onClose).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-authorization-section.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-authorization-section.spec.tsx new file mode 100644 index 0000000000..03b684faac --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-authorization-section.spec.tsx @@ -0,0 +1,64 @@ +import type { ToolWithProvider } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { CollectionType } from '@/app/components/tools/types' +import ToolAuthorizationSection from '../tool-authorization-section' + +vi.mock('@/app/components/plugins/plugin-auth', () => ({ + AuthCategory: { + tool: 'tool', + }, + PluginAuthInAgent: ({ pluginPayload, credentialId }: { + pluginPayload: { provider: string, providerType: string } + credentialId?: string + }) => ( +
+ {pluginPayload.provider} + : + {pluginPayload.providerType} + : + {credentialId} +
+ ), +})) + +const createProvider = (overrides: Partial = {}): ToolWithProvider => ({ + name: 'provider-a', + type: CollectionType.builtIn, + allow_delete: true, + ...overrides, +}) as ToolWithProvider + +describe('ToolAuthorizationSection', () => { + it('returns null for providers that are not removable built-ins', () => { + const { container, rerender } = render( + , + ) + + expect(container).toBeEmptyDOMElement() + + rerender( + , + ) + + expect(container).toBeEmptyDOMElement() + }) + + it('renders the authorization panel for removable built-in providers', () => { + render( + , + ) + + expect(screen.getByTestId('plugin-auth-in-agent')).toHaveTextContent('provider-a:builtin:credential-1') + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-item.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-item.spec.tsx new file mode 100644 index 0000000000..9a689dec8c --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-item.spec.tsx @@ -0,0 +1,130 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import ToolItem from '../tool-item' + +let mcpAllowed = true + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: mcpAllowed, + }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-not-support-tooltip', () => ({ + default: () =>
mcp unavailable
, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/install-plugin-button', () => ({ + InstallPluginButton: ({ onSuccess }: { onSuccess: () => void }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/switch-plugin-version', () => ({ + SwitchPluginVersion: ({ onChange }: { onChange: () => void }) => ( + + ), +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ + children, + popupContent, + }: { + children: React.ReactNode + popupContent: React.ReactNode + }) => ( +
+ {children} +
{popupContent}
+
+ ), +})) + +describe('ToolItem', () => { + beforeEach(() => { + vi.clearAllMocks() + mcpAllowed = true + }) + + it('shows auth status actions for no-auth and auth-removed states', () => { + const { rerender } = render( + , + ) + + expect(screen.getByText('tools.notAuthorized')).toBeInTheDocument() + + rerender( + , + ) + + expect(screen.getByText('plugin.auth.authRemoved')).toBeInTheDocument() + }) + + it('surfaces install and version mismatch recovery actions', () => { + const onInstall = vi.fn() + const { rerender } = render( + , + ) + + fireEvent.click(screen.getByText('install plugin')) + expect(onInstall).toHaveBeenCalledTimes(1) + + rerender( + , + ) + + fireEvent.click(screen.getByText('switch version')) + expect(onInstall).toHaveBeenCalledTimes(2) + }) + + it('blocks unsupported MCP tools and still exposes error state', () => { + mcpAllowed = false + const { rerender } = render( + , + ) + + expect(screen.getByTestId('mcp-tooltip')).toBeInTheDocument() + + rerender( + , + ) + + expect(screen.getByText('tool failed')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-settings-panel.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-settings-panel.spec.tsx new file mode 100644 index 0000000000..56c98f695d --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-settings-panel.spec.tsx @@ -0,0 +1,100 @@ +import type { ToolWithProvider } from '@/app/components/workflow/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import ToolSettingsPanel from '../tool-settings-panel' + +vi.mock('@/app/components/base/tab-slider-plain', () => ({ + default: ({ + options, + onChange, + }: { + options: Array<{ value: string, text: string }> + onChange: (value: string) => void + }) => ( +
+ {options.map(option => ( + + ))} +
+ ), +})) + +vi.mock('@/app/components/workflow/nodes/tool/components/tool-form', () => ({ + default: ({ schema }: { schema: Array<{ name: string }> }) =>
{schema.map(item => item.name).join(',')}
, +})) + +vi.mock('../reasoning-config-form', () => ({ + default: ({ schemas }: { schemas: Array<{ name: string }> }) =>
{schemas.map(item => item.name).join(',')}
, +})) + +const baseProps = { + nodeId: 'node-1', + currType: 'settings' as const, + settingsFormSchemas: [{ name: 'api_key' }] as never[], + paramsFormSchemas: [{ name: 'temperature' }] as never[], + settingsValue: {}, + showTabSlider: true, + userSettingsOnly: false, + reasoningConfigOnly: false, + nodeOutputVars: [], + availableNodes: [], + onCurrTypeChange: vi.fn(), + onSettingsFormChange: vi.fn(), + onParamsFormChange: vi.fn(), + currentProvider: { + is_team_authorization: true, + } as ToolWithProvider, +} + +describe('ToolSettingsPanel', () => { + it('returns null when the provider is not team-authorized or has no forms', () => { + const { container, rerender } = render( + , + ) + + expect(container).toBeEmptyDOMElement() + + rerender( + , + ) + + expect(container).toBeEmptyDOMElement() + }) + + it('renders the settings form and lets the tab slider switch to params', () => { + const onCurrTypeChange = vi.fn() + render( + , + ) + + expect(screen.getByTestId('tool-form')).toHaveTextContent('api_key') + fireEvent.click(screen.getByText('plugin.detailPanel.toolSelector.params')) + + expect(onCurrTypeChange).toHaveBeenCalledWith('params') + }) + + it('renders params tips and the reasoning config form for params-only views', () => { + render( + , + ) + + expect(screen.getAllByText('plugin.detailPanel.toolSelector.paramsTip1')).toHaveLength(2) + expect(screen.getByTestId('reasoning-config-form')).toHaveTextContent('temperature') + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-trigger.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-trigger.spec.tsx new file mode 100644 index 0000000000..903e1ef687 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-trigger.spec.tsx @@ -0,0 +1,38 @@ +import type { ToolWithProvider } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import ToolTrigger from '../tool-trigger' + +vi.mock('@/app/components/workflow/block-icon', () => ({ + default: () =>
, +})) + +describe('ToolTrigger', () => { + it('renders the placeholder for the unconfigured state', () => { + render() + + expect(screen.getByText('plugin.detailPanel.toolSelector.placeholder')).toBeInTheDocument() + }) + + it('renders the selected provider icon and tool label', () => { + render( + , + ) + + expect(screen.getByTestId('block-icon')).toBeInTheDocument() + expect(screen.getByText('Search Tool')).toBeInTheDocument() + }) + + it('switches to the configure placeholder when requested', () => { + render() + + expect(screen.getByText('plugin.detailPanel.configureTool')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.helpers.ts b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.helpers.ts new file mode 100644 index 0000000000..86e42aab6b --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.helpers.ts @@ -0,0 +1,233 @@ +import type { Node } from 'reactflow' +import type { CredentialFormSchema } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { ToolFormSchema } from '@/app/components/tools/utils/to-form-schema' +import type { NodeOutPutVar, ValueSelector, Var } from '@/app/components/workflow/types' +import { produce } from 'immer' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' +import { VarType } from '@/app/components/workflow/types' + +export type ReasoningConfigInputValue = { + type?: VarKindType + value?: unknown + [key: string]: unknown +} | null + +export type ReasoningConfigInput = { + value: ReasoningConfigInputValue + auto?: 0 | 1 +} + +export type ReasoningConfigValue = Record + +export const getVarKindType = (type: string) => { + if (type === FormTypeEnum.file || type === FormTypeEnum.files) + return VarKindType.variable + + if ([FormTypeEnum.select, FormTypeEnum.checkbox, FormTypeEnum.textNumber, FormTypeEnum.array, FormTypeEnum.object].includes(type as FormTypeEnum)) + return VarKindType.constant + + if (type === FormTypeEnum.textInput || type === FormTypeEnum.secretInput) + return VarKindType.mixed + + return undefined +} + +export const resolveTargetVarType = (type: string) => { + if (type === FormTypeEnum.textInput || type === FormTypeEnum.secretInput) + return VarType.string + if (type === FormTypeEnum.textNumber) + return VarType.number + if (type === FormTypeEnum.files) + return VarType.arrayFile + if (type === FormTypeEnum.file) + return VarType.file + if (type === FormTypeEnum.checkbox) + return VarType.boolean + if (type === FormTypeEnum.object) + return VarType.object + if (type === FormTypeEnum.array) + return VarType.arrayObject + + return VarType.string +} + +export const createFilterVar = (type: string) => { + if (type === FormTypeEnum.textNumber) + return (varPayload: Var) => varPayload.type === VarType.number + + if (type === FormTypeEnum.textInput || type === FormTypeEnum.secretInput) + return (varPayload: Var) => [VarType.string, VarType.number, VarType.secret].includes(varPayload.type) + + if (type === FormTypeEnum.file || type === FormTypeEnum.files) + return (varPayload: Var) => [VarType.file, VarType.arrayFile].includes(varPayload.type) + + if (type === FormTypeEnum.checkbox) + return (varPayload: Var) => varPayload.type === VarType.boolean + + if (type === FormTypeEnum.object) + return (varPayload: Var) => varPayload.type === VarType.object + + if (type === FormTypeEnum.array) + return (varPayload: Var) => [VarType.array, VarType.arrayString, VarType.arrayNumber, VarType.arrayObject].includes(varPayload.type) + + return undefined +} + +export const getVisibleSelectOptions = ( + options: NonNullable, + value: ReasoningConfigValue, + language: string, +) => { + return options.filter((option) => { + if (option.show_on.length) + return option.show_on.every(showOnItem => value[showOnItem.variable]?.value?.value === showOnItem.value) + + return true + }).map(option => ({ + value: option.value, + name: option.label[language] || option.label.en_US, + })) +} + +export const updateInputAutoState = ( + value: ReasoningConfigValue, + variable: string, + enabled: boolean, + type: string, +) => { + return { + ...value, + [variable]: { + value: enabled ? null : { type: getVarKindType(type), value: null }, + auto: enabled ? 1 as const : 0 as const, + }, + } +} + +export const updateVariableTypeValue = ( + value: ReasoningConfigValue, + variable: string, + newType: VarKindType, + defaultValue: unknown, +) => { + return produce(value, (draft) => { + draft[variable].value = { + type: newType, + value: newType === VarKindType.variable ? '' : defaultValue, + } + }) +} + +export const updateReasoningValue = ( + value: ReasoningConfigValue, + variable: string, + type: string, + newValue: unknown, +) => { + return produce(value, (draft) => { + draft[variable].value = { + type: getVarKindType(type), + value: newValue, + } + }) +} + +export const mergeReasoningValue = ( + value: ReasoningConfigValue, + variable: string, + newValue: Record, +) => { + return produce(value, (draft) => { + const currentValue = draft[variable].value as Record | undefined + draft[variable].value = { + ...currentValue, + ...newValue, + } + }) +} + +export const updateVariableSelectorValue = ( + value: ReasoningConfigValue, + variable: string, + newValue: ValueSelector | string, +) => { + return produce(value, (draft) => { + draft[variable].value = { + type: VarKindType.variable, + value: newValue, + } + }) +} + +export const getFieldFlags = (type: string, varInput?: ReasoningConfigInputValue) => { + const isString = type === FormTypeEnum.textInput || type === FormTypeEnum.secretInput + const isNumber = type === FormTypeEnum.textNumber + const isObject = type === FormTypeEnum.object + const isArray = type === FormTypeEnum.array + const isFile = type === FormTypeEnum.file || type === FormTypeEnum.files + const isBoolean = type === FormTypeEnum.checkbox + const isSelect = type === FormTypeEnum.select + const isAppSelector = type === FormTypeEnum.appSelector + const isModelSelector = type === FormTypeEnum.modelSelector + const isConstant = varInput?.type === VarKindType.constant || !varInput?.type + + return { + isString, + isNumber, + isObject, + isArray, + isShowJSONEditor: isObject || isArray, + isFile, + isBoolean, + isSelect, + isAppSelector, + isModelSelector, + showTypeSwitch: isNumber || isObject || isArray, + isConstant, + showVariableSelector: isFile || varInput?.type === VarKindType.variable, + } +} + +export const createPickerProps = ({ + type, + value, + language, + schema, +}: { + type: string + value: ReasoningConfigValue + language: string + schema: ToolFormSchema +}) => { + return { + filterVar: createFilterVar(type), + schema: schema as Partial, + targetVarType: resolveTargetVarType(type), + selectItems: schema.options ? getVisibleSelectOptions(schema.options, value, language) : [], + } +} + +export const getFieldTitle = (labels: { [key: string]: string }, language: string) => { + return labels[language] || labels.en_US +} + +export const createEmptyAppValue = () => ({ + app_id: '', + inputs: {}, + files: [], +}) + +export const createReasoningFormContext = ({ + availableNodes, + nodeId, + nodeOutputVars, +}: { + availableNodes: Node[] + nodeId: string + nodeOutputVars: NodeOutPutVar[] +}) => ({ + availableNodes, + nodeId, + nodeOutputVars, +}) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx index 38328aa8b3..1edc147370 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx @@ -1,19 +1,16 @@ import type { Node } from 'reactflow' -import type { CredentialFormSchema } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { ReasoningConfigValue as ReasoningConfigValueShape } from './reasoning-config-form.helpers' import type { ToolFormSchema } from '@/app/components/tools/utils/to-form-schema' import type { SchemaRoot } from '@/app/components/workflow/nodes/llm/types' -import type { ToolVarInputs } from '@/app/components/workflow/nodes/tool/types' import type { NodeOutPutVar, ValueSelector, - Var, } from '@/app/components/workflow/types' import { RiArrowRightUpLine, RiBracesLine, } from '@remixicon/react' import { useBoolean } from 'ahooks' -import { produce } from 'immer' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Input from '@/app/components/base/input' @@ -31,21 +28,21 @@ import VarReferencePicker from '@/app/components/workflow/nodes/_base/components import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import MixedVariableTextInput from '@/app/components/workflow/nodes/tool/components/mixed-variable-text-input' import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' -import { VarType } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' +import { + createPickerProps, + getFieldFlags, + getFieldTitle, + mergeReasoningValue, + resolveTargetVarType, + updateInputAutoState, + updateReasoningValue, + updateVariableSelectorValue, + updateVariableTypeValue, +} from './reasoning-config-form.helpers' import SchemaModal from './schema-modal' -type ReasoningConfigInputValue = { - type?: VarKindType - value?: unknown -} | null - -type ReasoningConfigInput = { - value: ReasoningConfigInputValue - auto?: 0 | 1 -} - -export type ReasoningConfigValue = Record +export type ReasoningConfigValue = ReasoningConfigValueShape type Props = { value: ReasoningConfigValue @@ -66,79 +63,42 @@ const ReasoningConfigForm: React.FC = ({ }) => { const { t } = useTranslation() const language = useLanguage() - const getVarKindType = (type: string) => { - if (type === FormTypeEnum.file || type === FormTypeEnum.files) - return VarKindType.variable - if (type === FormTypeEnum.select || type === FormTypeEnum.checkbox || type === FormTypeEnum.textNumber || type === FormTypeEnum.array || type === FormTypeEnum.object) - return VarKindType.constant - if (type === FormTypeEnum.textInput || type === FormTypeEnum.secretInput) - return VarKindType.mixed - } const handleAutomatic = (key: string, val: boolean, type: string) => { - onChange({ - ...value, - [key]: { - value: val ? null : { type: getVarKindType(type), value: null }, - auto: val ? 1 : 0, - }, - }) + onChange(updateInputAutoState(value, key, val, type)) } + const handleTypeChange = useCallback((variable: string, defaultValue: unknown) => { return (newType: VarKindType) => { - const res = produce(value, (draft: ToolVarInputs) => { - draft[variable].value = { - type: newType, - value: newType === VarKindType.variable ? '' : defaultValue, - } - }) - onChange(res) + onChange(updateVariableTypeValue(value, variable, newType, defaultValue)) } }, [onChange, value]) + const handleValueChange = useCallback((variable: string, varType: string) => { return (newValue: unknown) => { - const res = produce(value, (draft: ToolVarInputs) => { - draft[variable].value = { - type: getVarKindType(varType), - value: newValue, - } - }) - onChange(res) + onChange(updateReasoningValue(value, variable, varType, newValue)) } }, [onChange, value]) + const handleAppChange = useCallback((variable: string) => { return (app: { app_id: string inputs: Record files?: unknown[] }) => { - const newValue = produce(value, (draft: ToolVarInputs) => { - draft[variable].value = app - }) - onChange(newValue) + onChange(updateReasoningValue(value, variable, FormTypeEnum.appSelector, app)) } }, [onChange, value]) + const handleModelChange = useCallback((variable: string) => { return (model: Record) => { - const newValue = produce(value, (draft: ToolVarInputs) => { - const currentValue = draft[variable].value as Record | undefined - draft[variable].value = { - ...currentValue, - ...model, - } - }) - onChange(newValue) + onChange(mergeReasoningValue(value, variable, model)) } }, [onChange, value]) + const handleVariableSelectorChange = useCallback((variable: string) => { return (newValue: ValueSelector | string) => { - const res = produce(value, (draft: ToolVarInputs) => { - draft[variable].value = { - type: VarKindType.variable, - value: newValue, - } - }) - onChange(res) + onChange(updateVariableSelectorValue(value, variable, newValue)) } }, [onChange, value]) @@ -165,6 +125,7 @@ const ReasoningConfigForm: React.FC = ({ options, } = schema const auto = value[variable]?.auto + const fieldTitle = getFieldTitle(label, language) const tooltipContent = (tooltip && ( = ({ /> )) const varInput = value[variable].value - const isString = type === FormTypeEnum.textInput || type === FormTypeEnum.secretInput - const isNumber = type === FormTypeEnum.textNumber - const isObject = type === FormTypeEnum.object - const isArray = type === FormTypeEnum.array - const isShowJSONEditor = isObject || isArray - const isFile = type === FormTypeEnum.file || type === FormTypeEnum.files - const isBoolean = type === FormTypeEnum.checkbox - const isSelect = type === FormTypeEnum.select - const isAppSelector = type === FormTypeEnum.appSelector - const isModelSelector = type === FormTypeEnum.modelSelector - const showTypeSwitch = isNumber || isObject || isArray - const isConstant = varInput?.type === VarKindType.constant || !varInput?.type - const showVariableSelector = isFile || varInput?.type === VarKindType.variable - const targetVarType = () => { - if (isString) - return VarType.string - else if (isNumber) - return VarType.number - else if (type === FormTypeEnum.files) - return VarType.arrayFile - else if (type === FormTypeEnum.file) - return VarType.file - else if (isBoolean) - return VarType.boolean - else if (isObject) - return VarType.object - else if (isArray) - return VarType.arrayObject - else - return VarType.string - } - const getFilterVar = () => { - if (isNumber) - return (varPayload: Var) => varPayload.type === VarType.number - else if (isString) - return (varPayload: Var) => [VarType.string, VarType.number, VarType.secret].includes(varPayload.type) - else if (isFile) - return (varPayload: Var) => [VarType.file, VarType.arrayFile].includes(varPayload.type) - else if (isBoolean) - return (varPayload: Var) => varPayload.type === VarType.boolean - else if (isObject) - return (varPayload: Var) => varPayload.type === VarType.object - else if (isArray) - return (varPayload: Var) => [VarType.array, VarType.arrayString, VarType.arrayNumber, VarType.arrayObject].includes(varPayload.type) - return undefined - } + const { + isString, + isNumber, + isShowJSONEditor, + isBoolean, + isSelect, + isAppSelector, + isModelSelector, + showTypeSwitch, + isConstant, + showVariableSelector, + } = getFieldFlags(type, varInput) + const pickerProps = createPickerProps({ + type, + value, + language, + schema, + }) return (
- {label[language] || label.en_US} + {fieldTitle} {required && ( * )} {tooltipContent} · - {targetVarType()} + {resolveTargetVarType(type)} {isShowJSONEditor && ( = ({ >
showSchema(input_schema as SchemaRoot, label[language] || label.en_US)} + onClick={() => showSchema(input_schema as SchemaRoot, fieldTitle)} >
@@ -295,12 +228,7 @@ const ReasoningConfigForm: React.FC = ({ { - if (option.show_on.length) - return option.show_on.every(showOnItem => value[showOnItem.variable]?.value?.value === showOnItem.value) - - return true - }).map(option => ({ value: option.value, name: option.label[language] || option.label.en_US }))} + items={pickerProps.selectItems} onSelect={item => handleValueChange(variable, type)(item.value as string)} placeholder={placeholder?.[language] || placeholder?.en_US} /> @@ -347,9 +275,9 @@ const ReasoningConfigForm: React.FC = ({ nodeId={nodeId} value={(varInput?.value as string | ValueSelector) || []} onChange={handleVariableSelectorChange(variable)} - filterVar={getFilterVar()} - schema={schema as Partial} - valueTypePlaceHolder={targetVarType()} + filterVar={pickerProps.filterVar} + schema={pickerProps.schema} + valueTypePlaceHolder={pickerProps.targetVarType} /> )}
diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/__tests__/index.spec.ts b/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/__tests__/index.spec.ts new file mode 100644 index 0000000000..33a05be1b8 --- /dev/null +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/__tests__/index.spec.ts @@ -0,0 +1,9 @@ +import { describe, expect, it } from 'vitest' +import { usePluginInstalledCheck, useToolSelectorState } from '../index' + +describe('tool-selector hooks index', () => { + it('re-exports the tool selector hooks', () => { + expect(usePluginInstalledCheck).toBeTypeOf('function') + expect(useToolSelectorState).toBeTypeOf('function') + }) +}) diff --git a/web/app/components/plugins/plugin-page/__tests__/context-provider.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/context-provider.spec.tsx new file mode 100644 index 0000000000..476ab8e145 --- /dev/null +++ b/web/app/components/plugins/plugin-page/__tests__/context-provider.spec.tsx @@ -0,0 +1,76 @@ +import { fireEvent, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useGlobalPublicStore } from '@/context/global-public-context' +import { renderWithNuqs } from '@/test/nuqs-testing' +import { usePluginPageContext } from '../context' +import { PluginPageContextProvider } from '../context-provider' + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: vi.fn(), +})) + +vi.mock('../../hooks', () => ({ + PLUGIN_PAGE_TABS_MAP: { + plugins: 'plugins', + marketplace: 'discover', + }, + usePluginPageTabs: () => [ + { value: 'plugins', text: 'Plugins' }, + { value: 'discover', text: 'Discover' }, + ], +})) + +const mockGlobalPublicStore = (enableMarketplace: boolean) => { + vi.mocked(useGlobalPublicStore).mockImplementation((selector) => { + const state = { systemFeatures: { enable_marketplace: enableMarketplace } } + return selector(state as Parameters[0]) + }) +} + +const Consumer = () => { + const currentPluginID = usePluginPageContext(v => v.currentPluginID) + const setCurrentPluginID = usePluginPageContext(v => v.setCurrentPluginID) + const options = usePluginPageContext(v => v.options) + + return ( +
+ {currentPluginID ?? 'none'} + {options.length} + +
+ ) +} + +describe('PluginPageContextProvider', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('filters out the marketplace tab when the feature is disabled', () => { + mockGlobalPublicStore(false) + + renderWithNuqs( + + + , + ) + + expect(screen.getByTestId('options-count')).toHaveTextContent('1') + }) + + it('keeps the query-state tab and updates the current plugin id', () => { + mockGlobalPublicStore(true) + + renderWithNuqs( + + + , + { searchParams: '?tab=discover' }, + ) + + fireEvent.click(screen.getByText('select plugin')) + + expect(screen.getByTestId('current-plugin')).toHaveTextContent('plugin-1') + expect(screen.getByTestId('options-count')).toHaveTextContent('2') + }) +}) diff --git a/web/app/components/plugins/plugin-page/__tests__/debug-info.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/debug-info.spec.tsx new file mode 100644 index 0000000000..ceec84a286 --- /dev/null +++ b/web/app/components/plugins/plugin-page/__tests__/debug-info.spec.tsx @@ -0,0 +1,89 @@ +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import DebugInfo from '../debug-info' + +const mockDebugKey = vi.hoisted(() => ({ + data: null as null | { key: string, host: string, port: number }, + isLoading: false, +})) + +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.example.com${path}`, +})) + +vi.mock('@/service/use-plugins', () => ({ + useDebugKey: () => mockDebugKey, +})) + +vi.mock('@/app/components/base/button', () => ({ + default: ({ children }: { children: React.ReactNode }) => , +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ + children, + disabled, + popupContent, + }: { + children: React.ReactNode + disabled?: boolean + popupContent: React.ReactNode + }) => ( +
+ {children} + {!disabled &&
{popupContent}
} +
+ ), +})) + +vi.mock('../../base/key-value-item', () => ({ + default: ({ + label, + value, + maskedValue, + }: { + label: string + value: string + maskedValue?: string + }) => ( +
+ {label} + : + {maskedValue || value} +
+ ), +})) + +describe('DebugInfo', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDebugKey.data = null + mockDebugKey.isLoading = false + }) + + it('renders nothing while the debug key is loading', () => { + mockDebugKey.isLoading = true + const { container } = render() + + expect(container.innerHTML).toBe('') + }) + + it('renders debug metadata and masks the key when info is available', () => { + mockDebugKey.data = { + host: '127.0.0.1', + port: 5001, + key: '12345678abcdefghijklmnopqrst87654321', + } + + render() + + expect(screen.getByTestId('debug-button')).toBeInTheDocument() + expect(screen.getByText('plugin.debugInfo.title')).toBeInTheDocument() + expect(screen.getByRole('link')).toHaveAttribute( + 'href', + 'https://docs.example.com/develop-plugin/features-and-specs/plugin-types/remote-debug-a-plugin', + ) + expect(screen.getByTestId('kv-URL')).toHaveTextContent('URL:127.0.0.1:5001') + expect(screen.getByTestId('kv-Key')).toHaveTextContent('Key:12345678********87654321') + }) +}) diff --git a/web/app/components/plugins/plugin-page/__tests__/install-plugin-dropdown.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/install-plugin-dropdown.spec.tsx new file mode 100644 index 0000000000..d3b72ebe5b --- /dev/null +++ b/web/app/components/plugins/plugin-page/__tests__/install-plugin-dropdown.spec.tsx @@ -0,0 +1,156 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import InstallPluginDropdown from '../install-plugin-dropdown' + +let portalOpen = false +const { + mockSystemFeatures, +} = vi.hoisted(() => ({ + mockSystemFeatures: { + enable_marketplace: true, + plugin_installation_permission: { + restrict_to_marketplace_only: false, + }, + }, +})) + +vi.mock('@/config', () => ({ + SUPPORT_INSTALL_LOCAL_FILE_EXTENSIONS: '.difypkg,.zip', +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: typeof mockSystemFeatures }) => unknown) => + selector({ systemFeatures: mockSystemFeatures }), +})) + +vi.mock('@/app/components/base/icons/src/vender/solid/files', () => ({ + FileZip: () => file, +})) + +vi.mock('@/app/components/base/icons/src/vender/solid/general', () => ({ + Github: () => github, +})) + +vi.mock('@/app/components/base/icons/src/vender/solid/mediaAndDevices', () => ({ + MagicBox: () => magic, +})) + +vi.mock('@/app/components/base/button', () => ({ + default: ({ children }: { children: React.ReactNode }) => {children}, +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', async () => { + const React = await import('react') + return { + PortalToFollowElem: ({ + open, + children, + }: { + open: boolean + children: React.ReactNode + }) => { + portalOpen = open + return
{children}
+ }, + PortalToFollowElemTrigger: ({ + children, + onClick, + }: { + children: React.ReactNode + onClick: () => void + }) => , + PortalToFollowElemContent: ({ + children, + }: { + children: React.ReactNode + }) => portalOpen ?
{children}
: null, + } +}) + +vi.mock('@/app/components/plugins/install-plugin/install-from-github', () => ({ + default: ({ onClose }: { onClose: () => void }) => ( +
+ +
+ ), +})) + +vi.mock('@/app/components/plugins/install-plugin/install-from-local-package', () => ({ + default: ({ + file, + onClose, + }: { + file: File + onClose: () => void + }) => ( +
+ {file.name} + +
+ ), +})) + +describe('InstallPluginDropdown', () => { + beforeEach(() => { + vi.clearAllMocks() + portalOpen = false + mockSystemFeatures.enable_marketplace = true + mockSystemFeatures.plugin_installation_permission.restrict_to_marketplace_only = false + }) + + it('shows all install methods when marketplace and custom installs are enabled', () => { + render() + + fireEvent.click(screen.getByTestId('dropdown-trigger')) + + expect(screen.getByText('plugin.installFrom')).toBeInTheDocument() + expect(screen.getByText('plugin.source.marketplace')).toBeInTheDocument() + expect(screen.getByText('plugin.source.github')).toBeInTheDocument() + expect(screen.getByText('plugin.source.local')).toBeInTheDocument() + }) + + it('shows only marketplace when installation is restricted', () => { + mockSystemFeatures.plugin_installation_permission.restrict_to_marketplace_only = true + + render() + + fireEvent.click(screen.getByTestId('dropdown-trigger')) + + expect(screen.getByText('plugin.source.marketplace')).toBeInTheDocument() + expect(screen.queryByText('plugin.source.github')).not.toBeInTheDocument() + expect(screen.queryByText('plugin.source.local')).not.toBeInTheDocument() + }) + + it('switches to marketplace when the marketplace action is selected', () => { + const onSwitchToMarketplaceTab = vi.fn() + render() + + fireEvent.click(screen.getByTestId('dropdown-trigger')) + fireEvent.click(screen.getByText('plugin.source.marketplace')) + + expect(onSwitchToMarketplaceTab).toHaveBeenCalledTimes(1) + }) + + it('opens the github installer when github is selected', () => { + render() + + fireEvent.click(screen.getByTestId('dropdown-trigger')) + fireEvent.click(screen.getByText('plugin.source.github')) + + expect(screen.getByTestId('github-modal')).toBeInTheDocument() + }) + + it('opens the local package installer when a file is selected', () => { + const { container } = render() + + fireEvent.click(screen.getByTestId('dropdown-trigger')) + fireEvent.change(container.querySelector('input[type="file"]')!, { + target: { + files: [new File(['content'], 'plugin.difypkg')], + }, + }) + + expect(screen.getByTestId('local-modal')).toBeInTheDocument() + expect(screen.getByText('plugin.difypkg')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/plugin-page/__tests__/plugins-panel.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/plugins-panel.spec.tsx new file mode 100644 index 0000000000..bad857077a --- /dev/null +++ b/web/app/components/plugins/plugin-page/__tests__/plugins-panel.spec.tsx @@ -0,0 +1,200 @@ +import type { PluginDetail } from '../../types' +import { fireEvent, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import PluginsPanel from '../plugins-panel' + +const mockState = vi.hoisted(() => ({ + filters: { + categories: [] as string[], + tags: [] as string[], + searchQuery: '', + }, + currentPluginID: undefined as string | undefined, +})) + +const mockSetFilters = vi.fn() +const mockSetCurrentPluginID = vi.fn() +const mockLoadNextPage = vi.fn() +const mockInvalidateInstalledPluginList = vi.fn() +const mockUseInstalledPluginList = vi.fn() +const mockPluginListWithLatestVersion = vi.fn<() => PluginDetail[]>(() => []) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: () => 'en_US', +})) + +vi.mock('@/i18n-config', () => ({ + renderI18nObject: (value: Record, locale: string) => value[locale] || '', +})) + +vi.mock('@/service/use-plugins', () => ({ + useInstalledPluginList: () => mockUseInstalledPluginList(), + useInvalidateInstalledPluginList: () => mockInvalidateInstalledPluginList, +})) + +vi.mock('../../hooks', () => ({ + usePluginsWithLatestVersion: () => mockPluginListWithLatestVersion(), +})) + +vi.mock('../context', () => ({ + usePluginPageContext: (selector: (value: { + filters: typeof mockState.filters + setFilters: typeof mockSetFilters + currentPluginID: string | undefined + setCurrentPluginID: typeof mockSetCurrentPluginID + }) => unknown) => selector({ + filters: mockState.filters, + setFilters: mockSetFilters, + currentPluginID: mockState.currentPluginID, + setCurrentPluginID: mockSetCurrentPluginID, + }), +})) + +vi.mock('../filter-management', () => ({ + default: ({ onFilterChange }: { onFilterChange: (filters: typeof mockState.filters) => void }) => ( + + ), +})) + +vi.mock('../empty', () => ({ + default: () =>
empty
, +})) + +vi.mock('../list', () => ({ + default: ({ pluginList }: { pluginList: PluginDetail[] }) =>
{pluginList.map(plugin => plugin.plugin_id).join(',')}
, +})) + +vi.mock('@/app/components/plugins/plugin-detail-panel', () => ({ + default: ({ detail, onHide, onUpdate }: { + detail?: PluginDetail + onHide: () => void + onUpdate: () => void + }) => ( +
+ {detail?.plugin_id ?? 'none'} + + +
+ ), +})) + +const createPlugin = (pluginId: string, label: string, tags: string[] = []): PluginDetail => ({ + id: pluginId, + created_at: '2024-01-01', + updated_at: '2024-01-02', + name: label, + plugin_id: pluginId, + plugin_unique_identifier: `${pluginId}-uid`, + declaration: { + category: 'tool', + name: pluginId, + label: { en_US: label }, + description: { en_US: `${label} description` }, + tags, + } as PluginDetail['declaration'], + installation_id: `${pluginId}-install`, + tenant_id: 'tenant-1', + endpoints_setups: 0, + endpoints_active: 0, + version: '1.0.0', + latest_version: '1.0.0', + latest_unique_identifier: `${pluginId}-uid`, + source: 'marketplace' as PluginDetail['source'], + status: 'active', + deprecated_reason: '', + alternative_plugin_id: '', +}) as PluginDetail + +describe('PluginsPanel', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useFakeTimers() + mockState.filters = { categories: [], tags: [], searchQuery: '' } + mockState.currentPluginID = undefined + mockUseInstalledPluginList.mockReturnValue({ + data: { plugins: [] }, + isLoading: false, + isFetching: false, + isLastPage: true, + loadNextPage: mockLoadNextPage, + }) + mockPluginListWithLatestVersion.mockReturnValue([]) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('renders the loading state while the plugin list is pending', () => { + mockUseInstalledPluginList.mockReturnValue({ + data: { plugins: [] }, + isLoading: true, + isFetching: false, + isLastPage: true, + loadNextPage: mockLoadNextPage, + }) + + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('filters the list and exposes the load-more action', () => { + mockState.filters.searchQuery = 'alpha' + mockPluginListWithLatestVersion.mockReturnValue([ + createPlugin('alpha-tool', 'Alpha Tool', ['search']), + createPlugin('beta-tool', 'Beta Tool', ['rag']), + ]) + mockUseInstalledPluginList.mockReturnValue({ + data: { plugins: [] }, + isLoading: false, + isFetching: false, + isLastPage: false, + loadNextPage: mockLoadNextPage, + }) + + render() + + expect(screen.getByTestId('plugin-list')).toHaveTextContent('alpha-tool') + expect(screen.queryByText('beta-tool')).not.toBeInTheDocument() + + fireEvent.click(screen.getByText('workflow.common.loadMore')) + fireEvent.click(screen.getByTestId('filter-management')) + vi.runAllTimers() + + expect(mockLoadNextPage).toHaveBeenCalled() + expect(mockSetFilters).toHaveBeenCalledWith({ + categories: [], + tags: [], + searchQuery: 'beta', + }) + }) + + it('renders the empty state and keeps the current plugin detail in sync', () => { + mockState.currentPluginID = 'beta-tool' + mockState.filters.searchQuery = 'missing' + mockPluginListWithLatestVersion.mockReturnValue([ + createPlugin('beta-tool', 'Beta Tool'), + ]) + + render() + + expect(screen.getByTestId('empty-state')).toBeInTheDocument() + expect(screen.getByTestId('plugin-detail-panel')).toHaveTextContent('beta-tool') + + fireEvent.click(screen.getByText('hide detail')) + fireEvent.click(screen.getByText('refresh detail')) + + expect(mockSetCurrentPluginID).toHaveBeenCalledWith(undefined) + expect(mockInvalidateInstalledPluginList).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/plugins/plugin-page/filter-management/__tests__/constant.spec.ts b/web/app/components/plugins/plugin-page/filter-management/__tests__/constant.spec.ts new file mode 100644 index 0000000000..7286ff549f --- /dev/null +++ b/web/app/components/plugins/plugin-page/filter-management/__tests__/constant.spec.ts @@ -0,0 +1,32 @@ +import type { Category, Tag } from '../constant' +import { describe, expect, it } from 'vitest' + +describe('filter-management constant types', () => { + it('accepts tag objects with binding counts', () => { + const tag: Tag = { + id: 'tag-1', + name: 'search', + type: 'plugin', + binding_count: 3, + } + + expect(tag).toEqual({ + id: 'tag-1', + name: 'search', + type: 'plugin', + binding_count: 3, + }) + }) + + it('accepts supported category names', () => { + const category: Category = { + name: 'tool', + binding_count: 8, + } + + expect(category).toEqual({ + name: 'tool', + binding_count: 8, + }) + }) +}) diff --git a/web/app/components/plugins/plugin-page/filter-management/__tests__/tag-filter.spec.tsx b/web/app/components/plugins/plugin-page/filter-management/__tests__/tag-filter.spec.tsx new file mode 100644 index 0000000000..ff3cd3d97c --- /dev/null +++ b/web/app/components/plugins/plugin-page/filter-management/__tests__/tag-filter.spec.tsx @@ -0,0 +1,76 @@ +import { fireEvent, render, screen, within } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import TagFilter from '../tag-filter' + +let portalOpen = false + +vi.mock('../../../hooks', () => ({ + useTags: () => ({ + tags: [ + { name: 'agent', label: 'Agent' }, + { name: 'rag', label: 'RAG' }, + { name: 'search', label: 'Search' }, + ], + getTagLabel: (name: string) => ({ + agent: 'Agent', + rag: 'RAG', + search: 'Search', + }[name] ?? name), + }), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ + children, + open, + }: { + children: React.ReactNode + open: boolean + }) => { + portalOpen = open + return
{children}
+ }, + PortalToFollowElemTrigger: ({ + children, + onClick, + }: { + children: React.ReactNode + onClick: () => void + }) => , + PortalToFollowElemContent: ({ + children, + }: { + children: React.ReactNode + }) => portalOpen ?
{children}
: null, +})) + +describe('TagFilter', () => { + beforeEach(() => { + vi.clearAllMocks() + portalOpen = false + }) + + it('renders selected tag labels and the overflow counter', () => { + render() + + expect(screen.getByText('Agent,RAG')).toBeInTheDocument() + expect(screen.getByText('+1')).toBeInTheDocument() + }) + + it('filters options by search text and toggles tag selection', () => { + const onChange = vi.fn() + render() + + fireEvent.click(screen.getByTestId('trigger')) + const portal = screen.getByTestId('portal-content') + + fireEvent.change(screen.getByPlaceholderText('pluginTags.searchTags'), { target: { value: 'ra' } }) + + expect(within(portal).queryByText('Agent')).not.toBeInTheDocument() + expect(within(portal).getByText('RAG')).toBeInTheDocument() + + fireEvent.click(within(portal).getByText('RAG')) + + expect(onChange).toHaveBeenCalledWith(['agent', 'rag']) + }) +}) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/config.spec.ts b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/config.spec.ts new file mode 100644 index 0000000000..36450a4386 --- /dev/null +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/config.spec.ts @@ -0,0 +1,15 @@ +import { describe, expect, it } from 'vitest' +import { defaultValue } from '../config' +import { AUTO_UPDATE_MODE, AUTO_UPDATE_STRATEGY } from '../types' + +describe('auto-update config', () => { + it('provides the expected default auto update value', () => { + expect(defaultValue).toEqual({ + strategy_setting: AUTO_UPDATE_STRATEGY.disabled, + upgrade_time_of_day: 0, + upgrade_mode: AUTO_UPDATE_MODE.update_all, + exclude_plugins: [], + include_plugins: [], + }) + }) +}) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/no-data-placeholder.spec.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/no-data-placeholder.spec.tsx new file mode 100644 index 0000000000..d205682690 --- /dev/null +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/no-data-placeholder.spec.tsx @@ -0,0 +1,19 @@ +import { render, screen } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import NoDataPlaceholder from '../no-data-placeholder' + +describe('NoDataPlaceholder', () => { + it('renders the no-found state by default', () => { + const { container } = render() + + expect(container.querySelector('svg')).toBeInTheDocument() + expect(screen.getByText('plugin.autoUpdate.noPluginPlaceholder.noFound')).toBeInTheDocument() + }) + + it('renders the no-installed state when noPlugins is true', () => { + const { container } = render() + + expect(container.querySelector('svg')).toBeInTheDocument() + expect(screen.getByText('plugin.autoUpdate.noPluginPlaceholder.noInstalled')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/no-plugin-selected.spec.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/no-plugin-selected.spec.tsx new file mode 100644 index 0000000000..ba172ad3d6 --- /dev/null +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/no-plugin-selected.spec.tsx @@ -0,0 +1,18 @@ +import { render, screen } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import NoPluginSelected from '../no-plugin-selected' +import { AUTO_UPDATE_MODE } from '../types' + +describe('NoPluginSelected', () => { + it('renders partial mode placeholder', () => { + render() + + expect(screen.getByText('plugin.autoUpdate.upgradeModePlaceholder.partial')).toBeInTheDocument() + }) + + it('renders exclude mode placeholder', () => { + render() + + expect(screen.getByText('plugin.autoUpdate.upgradeModePlaceholder.exclude')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/plugins-picker.spec.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/plugins-picker.spec.tsx new file mode 100644 index 0000000000..4330f35bb4 --- /dev/null +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/plugins-picker.spec.tsx @@ -0,0 +1,82 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import PluginsPicker from '../plugins-picker' +import { AUTO_UPDATE_MODE } from '../types' + +const mockToolPicker = vi.fn() + +vi.mock('@/app/components/base/button', () => ({ + default: ({ + children, + }: { + children: React.ReactNode + }) => , +})) + +vi.mock('../no-plugin-selected', () => ({ + default: ({ updateMode }: { updateMode: AUTO_UPDATE_MODE }) =>
{updateMode}
, +})) + +vi.mock('../plugins-selected', () => ({ + default: ({ plugins }: { plugins: string[] }) =>
{plugins.join(',')}
, +})) + +vi.mock('../tool-picker', () => ({ + default: (props: Record) => { + mockToolPicker(props) + return
tool-picker
+ }, +})) + +describe('PluginsPicker', () => { + it('renders the empty state when no plugins are selected', () => { + render( + , + ) + + expect(screen.getByTestId('no-plugin-selected')).toHaveTextContent(AUTO_UPDATE_MODE.partial) + expect(screen.queryByTestId('plugins-selected')).not.toBeInTheDocument() + expect(mockToolPicker).toHaveBeenCalledWith(expect.objectContaining({ + value: [], + isShow: false, + onShowChange: expect.any(Function), + })) + }) + + it('renders selected plugins summary and clears them', () => { + const onChange = vi.fn() + render( + , + ) + + expect(screen.getByText('plugin.autoUpdate.excludeUpdate:{"num":2}')).toBeInTheDocument() + expect(screen.getByTestId('plugins-selected')).toHaveTextContent('dify/plugin-1,dify/plugin-2') + + fireEvent.click(screen.getByText('plugin.autoUpdate.operation.clearAll')) + + expect(onChange).toHaveBeenCalledWith([]) + }) + + it('passes the select button trigger into ToolPicker', () => { + render( + , + ) + + expect(screen.getByTestId('tool-picker')).toBeInTheDocument() + expect(mockToolPicker).toHaveBeenCalledWith(expect.objectContaining({ + trigger: expect.anything(), + })) + }) +}) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/plugins-selected.spec.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/plugins-selected.spec.tsx new file mode 100644 index 0000000000..cc4693f89c --- /dev/null +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/plugins-selected.spec.tsx @@ -0,0 +1,29 @@ +import { render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import PluginsSelected from '../plugins-selected' + +vi.mock('@/config', () => ({ + MARKETPLACE_API_PREFIX: 'https://marketplace.example.com', +})) + +vi.mock('@/app/components/plugins/card/base/card-icon', () => ({ + default: ({ src }: { src: string }) =>
{src}
, +})) + +describe('PluginsSelected', () => { + it('renders all selected plugin icons when the count is below the limit', () => { + render() + + expect(screen.getAllByTestId('plugin-icon')).toHaveLength(2) + expect(screen.getByText('https://marketplace.example.com/plugins/dify/plugin-1/icon')).toBeInTheDocument() + expect(screen.queryByText('+1')).not.toBeInTheDocument() + }) + + it('renders the overflow badge when more than fourteen plugins are selected', () => { + const plugins = Array.from({ length: 16 }, (_, index) => `dify/plugin-${index}`) + render() + + expect(screen.getAllByTestId('plugin-icon')).toHaveLength(14) + expect(screen.getByText('+2')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/strategy-picker.spec.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/strategy-picker.spec.tsx new file mode 100644 index 0000000000..aec57a2739 --- /dev/null +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/strategy-picker.spec.tsx @@ -0,0 +1,100 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import StrategyPicker from '../strategy-picker' +import { AUTO_UPDATE_STRATEGY } from '../types' + +let portalOpen = false + +vi.mock('@/app/components/base/button', () => ({ + default: ({ + children, + }: { + children: React.ReactNode + }) => {children}, +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', async () => { + const React = await import('react') + return { + PortalToFollowElem: ({ + open, + children, + }: { + open: boolean + children: React.ReactNode + }) => { + portalOpen = open + return
{children}
+ }, + PortalToFollowElemTrigger: ({ + children, + onClick, + }: { + children: React.ReactNode + onClick: (event: { stopPropagation: () => void, nativeEvent: { stopImmediatePropagation: () => void } }) => void + }) => ( + + ), + PortalToFollowElemContent: ({ + children, + }: { + children: React.ReactNode + }) => portalOpen ?
{children}
: null, + } +}) + +describe('StrategyPicker', () => { + beforeEach(() => { + vi.clearAllMocks() + portalOpen = false + }) + + it('renders the selected strategy label in the trigger', () => { + render( + , + ) + + expect(screen.getByTestId('trigger')).toHaveTextContent('plugin.autoUpdate.strategy.fixOnly.name') + }) + + it('opens the option list when the trigger is clicked', () => { + render( + , + ) + + fireEvent.click(screen.getByTestId('trigger')) + + expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content').querySelectorAll('svg')).toHaveLength(1) + expect(screen.getByText('plugin.autoUpdate.strategy.latest.description')).toBeInTheDocument() + }) + + it('calls onChange when a new strategy is selected', () => { + const onChange = vi.fn() + render( + , + ) + + fireEvent.click(screen.getByTestId('trigger')) + fireEvent.click(screen.getByText('plugin.autoUpdate.strategy.latest.name')) + + expect(onChange).toHaveBeenCalledWith(AUTO_UPDATE_STRATEGY.latest) + }) +}) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/tool-item.spec.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/tool-item.spec.tsx new file mode 100644 index 0000000000..f15fe5933f --- /dev/null +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/tool-item.spec.tsx @@ -0,0 +1,65 @@ +import type { PluginDetail } from '@/app/components/plugins/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import ToolItem from '../tool-item' + +vi.mock('@/config', () => ({ + MARKETPLACE_API_PREFIX: 'https://marketplace.example.com', +})) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: () => 'en_US', +})) + +vi.mock('@/i18n-config', () => ({ + renderI18nObject: (value: Record, language: string) => value[language], +})) + +vi.mock('@/app/components/plugins/card/base/card-icon', () => ({ + default: ({ src }: { src: string }) =>
{src}
, +})) + +vi.mock('@/app/components/base/checkbox', () => ({ + default: ({ + checked, + onCheck, + }: { + checked?: boolean + onCheck: () => void + }) => ( + + ), +})) + +const payload = { + plugin_id: 'dify/plugin-1', + declaration: { + label: { + en_US: 'Plugin One', + zh_Hans: 'Plugin One', + }, + author: 'Dify', + }, +} as PluginDetail + +describe('ToolItem', () => { + it('renders plugin metadata and marketplace icon', () => { + render() + + expect(screen.getByText('Plugin One')).toBeInTheDocument() + expect(screen.getByText('Dify')).toBeInTheDocument() + expect(screen.getByText('https://marketplace.example.com/plugins/dify/plugin-1/icon')).toBeInTheDocument() + expect(screen.getByText('true')).toBeInTheDocument() + }) + + it('calls onCheckChange when checkbox is clicked', () => { + const onCheckChange = vi.fn() + render() + + fireEvent.click(screen.getByTestId('checkbox')) + + expect(onCheckChange).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/tool-picker.spec.tsx b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/tool-picker.spec.tsx new file mode 100644 index 0000000000..9e63622d3f --- /dev/null +++ b/web/app/components/plugins/reference-setting-modal/auto-update-setting/__tests__/tool-picker.spec.tsx @@ -0,0 +1,248 @@ +import type { PluginDetail } from '@/app/components/plugins/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PluginSource } from '@/app/components/plugins/types' +import ToolPicker from '../tool-picker' + +let portalOpen = false + +const mockInstalledPluginList = vi.hoisted(() => ({ + data: { + plugins: [] as PluginDetail[], + }, + isLoading: false, +})) + +vi.mock('@/service/use-plugins', () => ({ + useInstalledPluginList: () => mockInstalledPluginList, +})) + +vi.mock('@/app/components/base/loading', () => ({ + default: () =>
loading
, +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', async () => { + const React = await import('react') + return { + PortalToFollowElem: ({ + open, + children, + }: { + open: boolean + children: React.ReactNode + }) => { + portalOpen = open + return
{children}
+ }, + PortalToFollowElemTrigger: ({ + children, + onClick, + }: { + children: React.ReactNode + onClick: () => void + }) => , + PortalToFollowElemContent: ({ + children, + className, + }: { + children: React.ReactNode + className?: string + }) => portalOpen ?
{children}
: null, + } +}) + +vi.mock('@/app/components/plugins/marketplace/search-box', () => ({ + default: ({ + search, + tags, + onSearchChange, + onTagsChange, + placeholder, + }: { + search: string + tags: string[] + onSearchChange: (value: string) => void + onTagsChange: (value: string[]) => void + placeholder: string + }) => ( +
+
{placeholder}
+
{search}
+
{tags.join(',')}
+ + +
+ ), +})) + +vi.mock('../no-data-placeholder', () => ({ + default: ({ + noPlugins, + }: { + noPlugins?: boolean + }) =>
{String(noPlugins)}
, +})) + +vi.mock('../tool-item', () => ({ + default: ({ + payload, + isChecked, + onCheckChange, + }: { + payload: PluginDetail + isChecked?: boolean + onCheckChange: () => void + }) => ( +
+ {payload.plugin_id} + {String(isChecked)} + +
+ ), +})) + +const createPlugin = ( + pluginId: string, + source: PluginDetail['source'], + category: string, + tags: string[], +): PluginDetail => ({ + plugin_id: pluginId, + source, + declaration: { + category, + tags, + }, +} as PluginDetail) + +describe('ToolPicker', () => { + beforeEach(() => { + vi.clearAllMocks() + portalOpen = false + mockInstalledPluginList.data = { + plugins: [], + } + mockInstalledPluginList.isLoading = false + }) + + it('toggles popup visibility from the trigger', () => { + const onShowChange = vi.fn() + render( + trigger} + value={[]} + onChange={vi.fn()} + isShow={false} + onShowChange={onShowChange} + />, + ) + + fireEvent.click(screen.getByTestId('trigger')) + + expect(onShowChange).toHaveBeenCalledWith(true) + }) + + it('renders loading content while installed plugins are loading', () => { + mockInstalledPluginList.isLoading = true + + render( + trigger} + value={[]} + onChange={vi.fn()} + isShow + onShowChange={vi.fn()} + />, + ) + + expect(screen.getByTestId('loading')).toBeInTheDocument() + }) + + it('renders no-data placeholder when there are no matching marketplace plugins', () => { + render( + trigger} + value={[]} + onChange={vi.fn()} + isShow + onShowChange={vi.fn()} + />, + ) + + expect(screen.getByTestId('no-data')).toHaveTextContent('true') + }) + + it('filters by plugin type, tags, and query', () => { + mockInstalledPluginList.data = { + plugins: [ + createPlugin('tool-search', PluginSource.marketplace, 'tool', ['search']), + createPlugin('tool-rag', PluginSource.marketplace, 'tool', ['rag']), + createPlugin('model-agent', PluginSource.marketplace, 'model', ['agent']), + createPlugin('github-tool', PluginSource.github, 'tool', ['rag']), + ], + } + + render( + trigger} + value={[]} + onChange={vi.fn()} + isShow + onShowChange={vi.fn()} + />, + ) + + expect(screen.getAllByTestId('tool-item')).toHaveLength(3) + expect(screen.queryByText('github-tool')).not.toBeInTheDocument() + + fireEvent.click(screen.getByText('plugin.category.models')) + expect(screen.getAllByTestId('tool-item')).toHaveLength(1) + expect(screen.getByText('model-agent')).toBeInTheDocument() + + fireEvent.click(screen.getByText('plugin.category.tools')) + expect(screen.getAllByTestId('tool-item')).toHaveLength(2) + + fireEvent.click(screen.getByTestId('set-tags')) + expect(screen.getAllByTestId('tool-item')).toHaveLength(1) + expect(screen.getByText('tool-rag')).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('set-query')) + expect(screen.getAllByTestId('tool-item')).toHaveLength(1) + expect(screen.getByTestId('search-state')).toHaveTextContent('tool-rag') + }) + + it('adds and removes plugin ids from the selection', () => { + mockInstalledPluginList.data = { + plugins: [ + createPlugin('tool-rag', PluginSource.marketplace, 'tool', ['rag']), + createPlugin('tool-search', PluginSource.marketplace, 'tool', ['search']), + ], + } + const onChange = vi.fn() + const { rerender } = render( + trigger} + value={['tool-rag']} + onChange={onChange} + isShow + onShowChange={vi.fn()} + />, + ) + + fireEvent.click(screen.getByTestId('toggle-tool-search')) + expect(onChange).toHaveBeenCalledWith(['tool-rag', 'tool-search']) + + rerender( + trigger} + value={['tool-rag']} + onChange={onChange} + isShow + onShowChange={vi.fn()} + />, + ) + + fireEvent.click(screen.getByTestId('toggle-tool-rag')) + expect(onChange).toHaveBeenCalledWith([]) + }) +}) diff --git a/web/app/components/plugins/update-plugin/__tests__/from-market-place.spec.tsx b/web/app/components/plugins/update-plugin/__tests__/from-market-place.spec.tsx new file mode 100644 index 0000000000..b66ab20a45 --- /dev/null +++ b/web/app/components/plugins/update-plugin/__tests__/from-market-place.spec.tsx @@ -0,0 +1,226 @@ +import type { UpdateFromMarketPlacePayload } from '@/app/components/plugins/types' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PluginCategoryEnum, TaskStatus } from '@/app/components/plugins/types' +import UpdateFromMarketplace from '../from-market-place' + +const { + mockStop, + mockCheck, + mockHandleRefetch, + mockInvalidateReferenceSettings, + mockRemoveAutoUpgrade, + mockUpdateFromMarketPlace, + mockToastError, +} = vi.hoisted(() => ({ + mockStop: vi.fn(), + mockCheck: vi.fn(), + mockHandleRefetch: vi.fn(), + mockInvalidateReferenceSettings: vi.fn(), + mockRemoveAutoUpgrade: vi.fn(), + mockUpdateFromMarketPlace: vi.fn(), + mockToastError: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/dialog', () => ({ + Dialog: ({ children }: { children: React.ReactNode }) =>
{children}
, + DialogContent: ({ children }: { children: React.ReactNode }) =>
{children}
, + DialogTitle: ({ children }: { children: React.ReactNode }) =>
{children}
, + DialogCloseButton: () => , +})) + +vi.mock('@/app/components/base/badge/index', () => ({ + __esModule: true, + BadgeState: { + Warning: 'warning', + }, + default: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +vi.mock('@/app/components/base/button', () => ({ + default: ({ + children, + onClick, + disabled, + }: { + children: React.ReactNode + onClick?: () => void + disabled?: boolean + }) => , +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + }, +})) + +vi.mock('@/app/components/plugins/card', () => ({ + default: ({ titleLeft, payload }: { titleLeft: React.ReactNode, payload: { label: Record } }) => ( +
+
{payload.label.en_US}
+
{titleLeft}
+
+ ), +})) + +vi.mock('@/app/components/plugins/install-plugin/base/check-task-status', () => ({ + default: () => ({ + check: mockCheck, + stop: mockStop, + }), +})) + +vi.mock('@/app/components/plugins/install-plugin/utils', () => ({ + pluginManifestToCardPluginProps: (payload: unknown) => payload, +})) + +vi.mock('@/service/plugins', () => ({ + updateFromMarketPlace: mockUpdateFromMarketPlace, +})) + +vi.mock('@/service/use-plugins', () => ({ + usePluginTaskList: () => ({ + handleRefetch: mockHandleRefetch, + }), + useRemoveAutoUpgrade: () => ({ + mutateAsync: mockRemoveAutoUpgrade, + }), + useInvalidateReferenceSettings: () => mockInvalidateReferenceSettings, +})) + +vi.mock('../install-plugin/base/use-get-icon', () => ({ + default: () => ({ + getIconUrl: async (icon: string) => `https://cdn.example.com/${icon}`, + }), +})) + +vi.mock('../downgrade-warning', () => ({ + default: ({ + onCancel, + onJustDowngrade, + onExcludeAndDowngrade, + }: { + onCancel: () => void + onJustDowngrade: () => void + onExcludeAndDowngrade: () => void + }) => ( +
+ + + +
+ ), +})) + +const createPayload = (overrides: Partial = {}): UpdateFromMarketPlacePayload => ({ + category: PluginCategoryEnum.tool, + originalPackageInfo: { + id: 'plugin@1.0.0', + payload: { + version: '1.0.0', + icon: 'plugin.png', + label: { en_US: 'Plugin Label' }, + } as UpdateFromMarketPlacePayload['originalPackageInfo']['payload'], + }, + targetPackageInfo: { + id: 'plugin@2.0.0', + version: '2.0.0', + }, + ...overrides, +}) + +describe('UpdateFromMarketplace', () => { + beforeEach(() => { + vi.clearAllMocks() + mockCheck.mockResolvedValue({ status: TaskStatus.success }) + mockUpdateFromMarketPlace.mockResolvedValue({ + all_installed: true, + task_id: 'task-1', + }) + }) + + it('renders the upgrade modal content and current version transition', async () => { + render( + , + ) + + expect(screen.getByText('plugin.upgrade.title')).toBeInTheDocument() + expect(screen.getByText('plugin.upgrade.description')).toBeInTheDocument() + expect(screen.getByText('1.0.0 -> 2.0.0')).toBeInTheDocument() + await waitFor(() => { + expect(screen.getByTestId('plugin-card')).toHaveTextContent('Plugin Label') + }) + }) + + it('submits the marketplace upgrade and calls onSave when installation is immediate', async () => { + const onSave = vi.fn() + render( + , + ) + + fireEvent.click(screen.getByText('plugin.upgrade.upgrade')) + + await waitFor(() => { + expect(mockUpdateFromMarketPlace).toHaveBeenCalledWith({ + original_plugin_unique_identifier: 'plugin@1.0.0', + new_plugin_unique_identifier: 'plugin@2.0.0', + }) + expect(onSave).toHaveBeenCalled() + }) + }) + + it('surfaces failed upgrade messages from the response task payload', async () => { + mockUpdateFromMarketPlace.mockResolvedValue({ + task: { + status: TaskStatus.failed, + plugins: [{ + plugin_unique_identifier: 'plugin@2.0.0', + message: 'upgrade failed', + }], + }, + }) + + render( + , + ) + + fireEvent.click(screen.getByText('plugin.upgrade.upgrade')) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('upgrade failed') + }) + }) + + it('removes auto-upgrade before downgrading when the warning modal is shown', async () => { + render( + , + ) + + fireEvent.click(screen.getByText('exclude and downgrade')) + + await waitFor(() => { + expect(mockRemoveAutoUpgrade).toHaveBeenCalledWith({ plugin_id: 'plugin-1' }) + expect(mockInvalidateReferenceSettings).toHaveBeenCalled() + expect(mockUpdateFromMarketPlace).toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/plugins/update-plugin/__tests__/plugin-version-picker.spec.tsx b/web/app/components/plugins/update-plugin/__tests__/plugin-version-picker.spec.tsx new file mode 100644 index 0000000000..b65c6a6e42 --- /dev/null +++ b/web/app/components/plugins/update-plugin/__tests__/plugin-version-picker.spec.tsx @@ -0,0 +1,107 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import PluginVersionPicker from '../plugin-version-picker' + +type VersionItem = { + version: string + unique_identifier: string + created_at: string +} + +const mockVersionList = vi.hoisted(() => ({ + data: { + versions: [] as VersionItem[], + }, +})) + +vi.mock('@/hooks/use-timestamp', () => ({ + default: () => ({ + formatDate: (value: string, format: string) => `${value}:${format}`, + }), +})) + +vi.mock('@/service/use-plugins', () => ({ + useVersionListOfPlugin: () => ({ + data: mockVersionList, + }), +})) + +describe('PluginVersionPicker', () => { + beforeEach(() => { + vi.clearAllMocks() + mockVersionList.data.versions = [ + { + version: '2.0.0', + unique_identifier: 'uid-current', + created_at: '2024-01-02', + }, + { + version: '1.0.0', + unique_identifier: 'uid-old', + created_at: '2023-12-01', + }, + ] + }) + + it('renders version options and highlights the current version', () => { + render( + trigger} + onSelect={vi.fn()} + />, + ) + + expect(screen.getByText('plugin.detailPanel.switchVersion')).toBeInTheDocument() + expect(screen.getByText('2.0.0')).toBeInTheDocument() + expect(screen.getByText('2024-01-02:appLog.dateTimeFormat')).toBeInTheDocument() + expect(screen.getByText('CURRENT')).toBeInTheDocument() + }) + + it('calls onSelect with downgrade metadata and closes the picker', () => { + const onSelect = vi.fn() + const onShowChange = vi.fn() + + render( + trigger} + onSelect={onSelect} + />, + ) + + fireEvent.click(screen.getByText('1.0.0')) + + expect(onSelect).toHaveBeenCalledWith({ + version: '1.0.0', + unique_identifier: 'uid-old', + isDowngrade: true, + }) + expect(onShowChange).toHaveBeenCalledWith(false) + }) + + it('does not call onSelect when the current version is clicked', () => { + const onSelect = vi.fn() + + render( + trigger} + onSelect={onSelect} + />, + ) + + fireEvent.click(screen.getByText('2.0.0')) + + expect(onSelect).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/rag-pipeline/components/__tests__/rag-pipeline-children.spec.tsx b/web/app/components/rag-pipeline/components/__tests__/rag-pipeline-children.spec.tsx new file mode 100644 index 0000000000..fcb208fc67 --- /dev/null +++ b/web/app/components/rag-pipeline/components/__tests__/rag-pipeline-children.spec.tsx @@ -0,0 +1,141 @@ +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import { act, fireEvent, render, screen } from '@testing-library/react' +import { DSL_EXPORT_CHECK } from '@/app/components/workflow/constants' +import RagPipelineChildren from '../rag-pipeline-children' + +let mockShowImportDSLModal = false +let mockSubscription: ((value: { type: string, payload?: { data?: EnvironmentVariable[] } }) => void) | null = null + +const { + mockSetShowImportDSLModal, + mockHandlePaneContextmenuCancel, + mockExportCheck, + mockHandleExportDSL, + mockUseRagPipelineSearch, +} = vi.hoisted(() => ({ + mockSetShowImportDSLModal: vi.fn((value: boolean) => { + mockShowImportDSLModal = value + }), + mockHandlePaneContextmenuCancel: vi.fn(), + mockExportCheck: vi.fn(), + mockHandleExportDSL: vi.fn(), + mockUseRagPipelineSearch: vi.fn(), +})) + +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + useSubscription: (callback: (value: { type: string, payload?: { data?: EnvironmentVariable[] } }) => void) => { + mockSubscription = callback + }, + }, + }), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { + showImportDSLModal: boolean + setShowImportDSLModal: typeof mockSetShowImportDSLModal + }) => unknown) => selector({ + showImportDSLModal: mockShowImportDSLModal, + setShowImportDSLModal: mockSetShowImportDSLModal, + }), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useDSL: () => ({ + exportCheck: mockExportCheck, + handleExportDSL: mockHandleExportDSL, + }), + usePanelInteractions: () => ({ + handlePaneContextmenuCancel: mockHandlePaneContextmenuCancel, + }), +})) + +vi.mock('../../hooks/use-rag-pipeline-search', () => ({ + useRagPipelineSearch: mockUseRagPipelineSearch, +})) + +vi.mock('../../../workflow/plugin-dependency', () => ({ + default: () =>
, +})) + +vi.mock('../panel', () => ({ + default: () =>
, +})) + +vi.mock('../publish-toast', () => ({ + default: () =>
, +})) + +vi.mock('../rag-pipeline-header', () => ({ + default: () =>
, +})) + +vi.mock('../update-dsl-modal', () => ({ + default: ({ onCancel }: { onCancel: () => void }) => ( +
+ +
+ ), +})) + +vi.mock('@/app/components/workflow/dsl-export-confirm-modal', () => ({ + default: ({ + envList, + onConfirm, + onClose, + }: { + envList: EnvironmentVariable[] + onConfirm: () => void + onClose: () => void + }) => ( +
+
{envList.map(env => env.name).join(',')}
+ + +
+ ), +})) + +describe('RagPipelineChildren', () => { + beforeEach(() => { + vi.clearAllMocks() + mockShowImportDSLModal = false + mockSubscription = null + }) + + it('should render the main pipeline children and the import modal when enabled', () => { + mockShowImportDSLModal = true + + render() + + fireEvent.click(screen.getByText('close import')) + + expect(mockUseRagPipelineSearch).toHaveBeenCalledTimes(1) + expect(screen.getByTestId('plugin-dependency')).toBeInTheDocument() + expect(screen.getByTestId('rag-header')).toBeInTheDocument() + expect(screen.getByTestId('rag-panel')).toBeInTheDocument() + expect(screen.getByTestId('publish-toast')).toBeInTheDocument() + expect(screen.getByTestId('update-dsl-modal')).toBeInTheDocument() + expect(mockSetShowImportDSLModal).toHaveBeenCalledWith(false) + }) + + it('should show the DSL export confirmation modal after receiving the export event', () => { + render() + + act(() => { + mockSubscription?.({ + type: DSL_EXPORT_CHECK, + payload: { + data: [{ name: 'API_KEY' } as EnvironmentVariable], + }, + }) + }) + + fireEvent.click(screen.getByText('confirm export')) + + expect(screen.getByTestId('dsl-export-modal')).toHaveTextContent('API_KEY') + expect(mockHandleExportDSL).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/rag-pipeline/components/__tests__/screenshot.spec.tsx b/web/app/components/rag-pipeline/components/__tests__/screenshot.spec.tsx new file mode 100644 index 0000000000..1854b2a683 --- /dev/null +++ b/web/app/components/rag-pipeline/components/__tests__/screenshot.spec.tsx @@ -0,0 +1,29 @@ +import { render, screen } from '@testing-library/react' +import PipelineScreenShot from '../screenshot' + +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ + theme: 'dark', + }), +})) + +vi.mock('@/utils/var', () => ({ + basePath: '/console', +})) + +describe('PipelineScreenShot', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should build themed screenshot sources', () => { + const { container } = render() + const sources = container.querySelectorAll('source') + + expect(sources).toHaveLength(3) + expect(sources[0]).toHaveAttribute('srcset', '/console/screenshots/dark/Pipeline.png') + expect(sources[1]).toHaveAttribute('srcset', '/console/screenshots/dark/Pipeline@2x.png') + expect(sources[2]).toHaveAttribute('srcset', '/console/screenshots/dark/Pipeline@3x.png') + expect(screen.getByAltText('Pipeline Screenshot')).toHaveAttribute('src', '/console/screenshots/dark/Pipeline.png') + }) +}) diff --git a/web/app/components/rag-pipeline/components/chunk-card-list/__tests__/q-a-item.spec.tsx b/web/app/components/rag-pipeline/components/chunk-card-list/__tests__/q-a-item.spec.tsx new file mode 100644 index 0000000000..43dffb80f9 --- /dev/null +++ b/web/app/components/rag-pipeline/components/chunk-card-list/__tests__/q-a-item.spec.tsx @@ -0,0 +1,23 @@ +import { render, screen } from '@testing-library/react' +import QAItem from '../q-a-item' +import { QAItemType } from '../types' + +describe('QAItem', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the question prefix', () => { + render() + + expect(screen.getByText('Q')).toBeInTheDocument() + expect(screen.getByText('What is Dify?')).toBeInTheDocument() + }) + + it('should render the answer prefix', () => { + render() + + expect(screen.getByText('A')).toBeInTheDocument() + expect(screen.getByText('An LLM app platform.')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/editor/__tests__/utils.spec.ts b/web/app/components/rag-pipeline/components/panel/input-field/editor/__tests__/utils.spec.ts new file mode 100644 index 0000000000..e4e53a4c5b --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/editor/__tests__/utils.spec.ts @@ -0,0 +1,97 @@ +import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import { VAR_ITEM_TEMPLATE_IN_PIPELINE } from '@/config' +import { PipelineInputVarType } from '@/models/pipeline' +import { TransferMethod } from '@/types/app' +import { convertFormDataToINputField, convertToInputFieldFormData } from '../utils' + +describe('input-field editor utils', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should convert pipeline input vars into form data', () => { + const result = convertToInputFieldFormData({ + type: PipelineInputVarType.multiFiles, + label: 'Upload files', + variable: 'documents', + max_length: 5, + default_value: 'initial-value', + required: false, + tooltips: 'Tooltip text', + options: ['a', 'b'], + placeholder: 'Select files', + unit: 'MB', + allowed_file_upload_methods: [TransferMethod.local_file], + allowed_file_types: [SupportUploadFileTypes.document], + allowed_file_extensions: ['pdf'], + }) + + expect(result).toEqual({ + type: PipelineInputVarType.multiFiles, + label: 'Upload files', + variable: 'documents', + maxLength: 5, + default: 'initial-value', + required: false, + tooltips: 'Tooltip text', + options: ['a', 'b'], + placeholder: 'Select files', + unit: 'MB', + allowedFileUploadMethods: [TransferMethod.local_file], + allowedTypesAndExtensions: { + allowedFileTypes: [SupportUploadFileTypes.document], + allowedFileExtensions: ['pdf'], + }, + }) + }) + + it('should fall back to the default input variable template', () => { + const result = convertToInputFieldFormData() + + expect(result).toEqual({ + type: VAR_ITEM_TEMPLATE_IN_PIPELINE.type, + label: VAR_ITEM_TEMPLATE_IN_PIPELINE.label, + variable: VAR_ITEM_TEMPLATE_IN_PIPELINE.variable, + maxLength: undefined, + required: VAR_ITEM_TEMPLATE_IN_PIPELINE.required, + options: VAR_ITEM_TEMPLATE_IN_PIPELINE.options, + allowedTypesAndExtensions: {}, + }) + }) + + it('should convert form data back into pipeline input variables', () => { + const result = convertFormDataToINputField({ + type: PipelineInputVarType.select, + label: 'Category', + variable: 'category', + maxLength: 10, + default: 'books', + required: true, + tooltips: 'Pick one', + options: ['books', 'music'], + placeholder: 'Choose', + unit: '', + allowedFileUploadMethods: [TransferMethod.local_file], + allowedTypesAndExtensions: { + allowedFileTypes: [SupportUploadFileTypes.document], + allowedFileExtensions: ['txt'], + }, + }) + + expect(result).toEqual({ + type: PipelineInputVarType.select, + label: 'Category', + variable: 'category', + max_length: 10, + default_value: 'books', + required: true, + tooltips: 'Pick one', + options: ['books', 'music'], + placeholder: 'Choose', + unit: '', + allowed_file_upload_methods: [TransferMethod.local_file], + allowed_file_types: [SupportUploadFileTypes.document], + allowed_file_extensions: ['txt'], + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/hidden-fields.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/hidden-fields.spec.tsx new file mode 100644 index 0000000000..0a5b748c7b --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/hidden-fields.spec.tsx @@ -0,0 +1,73 @@ +import type { InputFieldFormProps } from '../types' +import { render, screen } from '@testing-library/react' +import { useAppForm } from '@/app/components/base/form' +import HiddenFields from '../hidden-fields' +import { useHiddenConfigurations } from '../hooks' + +const { mockInputField } = vi.hoisted(() => ({ + mockInputField: vi.fn(({ config }: { config: { variable: string } }) => { + return function FieldComponent() { + return
{config.variable}
+ } + }), +})) + +vi.mock('@/app/components/base/form/form-scenarios/input-field/field', () => ({ + default: mockInputField, +})) + +vi.mock('../hooks', () => ({ + useHiddenConfigurations: vi.fn(), +})) + +describe('HiddenFields', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should build fields from the hidden configuration list', () => { + vi.mocked(useHiddenConfigurations).mockReturnValue([ + { variable: 'default' }, + { variable: 'tooltips' }, + ] as ReturnType) + + const HiddenFieldsHarness = () => { + const initialData: InputFieldFormProps['initialData'] = { + variable: 'field_1', + options: ['option-a', 'option-b'], + } + const form = useAppForm({ + defaultValues: initialData, + onSubmit: () => {}, + }) + const HiddenFieldsComp = HiddenFields({ initialData }) + return + } + render() + + expect(useHiddenConfigurations).toHaveBeenCalledWith({ + options: ['option-a', 'option-b'], + }) + expect(mockInputField).toHaveBeenCalledTimes(2) + expect(screen.getAllByTestId('input-field')).toHaveLength(2) + expect(screen.getByText('default')).toBeInTheDocument() + expect(screen.getByText('tooltips')).toBeInTheDocument() + }) + + it('should render nothing when there are no hidden configurations', () => { + vi.mocked(useHiddenConfigurations).mockReturnValue([]) + + const HiddenFieldsHarness = () => { + const initialData: InputFieldFormProps['initialData'] = { options: [] } + const form = useAppForm({ + defaultValues: initialData, + onSubmit: () => {}, + }) + const HiddenFieldsComp = HiddenFields({ initialData }) + return + } + const { container } = render() + + expect(container).toBeEmptyDOMElement() + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/initial-fields.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/initial-fields.spec.tsx new file mode 100644 index 0000000000..e6bf21ed74 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/initial-fields.spec.tsx @@ -0,0 +1,85 @@ +import type { ComponentType } from 'react' +import { render, screen } from '@testing-library/react' +import { useConfigurations } from '../hooks' +import InitialFields from '../initial-fields' + +type MockForm = { + store: object + getFieldValue: (fieldName: string) => unknown + setFieldValue: (fieldName: string, value: unknown) => void +} + +const { + mockForm, + mockInputField, +} = vi.hoisted(() => ({ + mockForm: { + store: {}, + getFieldValue: vi.fn(), + setFieldValue: vi.fn(), + } as MockForm, + mockInputField: vi.fn(({ config }: { config: { variable: string } }) => { + return function FieldComponent() { + return
{config.variable}
+ } + }), +})) + +vi.mock('@/app/components/base/form', () => ({ + withForm: ({ render }: { + render: (props: { form: MockForm }) => React.ReactNode + }) => ({ form }: { form?: MockForm }) => render({ form: form ?? mockForm }), +})) + +vi.mock('@/app/components/base/form/form-scenarios/input-field/field', () => ({ + default: mockInputField, +})) + +vi.mock('../hooks', () => ({ + useConfigurations: vi.fn(), +})) + +describe('InitialFields', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should build initial fields with the form accessors and supportFile flag', () => { + vi.mocked(useConfigurations).mockReturnValue([ + { variable: 'type' }, + { variable: 'label' }, + ] as ReturnType) + + const InitialFieldsComp = InitialFields({ + initialData: { variable: 'field_1' }, + supportFile: true, + }) as unknown as ComponentType + render() + + expect(useConfigurations).toHaveBeenCalledWith(expect.objectContaining({ + supportFile: true, + getFieldValue: expect.any(Function), + setFieldValue: expect.any(Function), + })) + expect(screen.getAllByTestId('input-field')).toHaveLength(2) + expect(screen.getByText('type')).toBeInTheDocument() + expect(screen.getByText('label')).toBeInTheDocument() + }) + + it('should delegate field accessors to the underlying form instance', () => { + vi.mocked(useConfigurations).mockReturnValue([] as ReturnType) + mockForm.getFieldValue = vi.fn(() => 'label-value') + mockForm.setFieldValue = vi.fn() + + const InitialFieldsComp = InitialFields({ supportFile: false }) as unknown as ComponentType + render() + + const call = vi.mocked(useConfigurations).mock.calls[0]?.[0] + const value = call?.getFieldValue('label') + call?.setFieldValue('label', 'next-value') + + expect(value).toBe('label-value') + expect(mockForm.getFieldValue).toHaveBeenCalledWith('label') + expect(mockForm.setFieldValue).toHaveBeenCalledWith('label', 'next-value') + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/show-all-settings.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/show-all-settings.spec.tsx new file mode 100644 index 0000000000..9dd943f969 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/__tests__/show-all-settings.spec.tsx @@ -0,0 +1,62 @@ +import type { InputFieldFormProps } from '../types' +import { fireEvent, render, screen } from '@testing-library/react' +import { useAppForm } from '@/app/components/base/form' +import { PipelineInputVarType } from '@/models/pipeline' +import { useHiddenFieldNames } from '../hooks' +import ShowAllSettings from '../show-all-settings' + +vi.mock('../hooks', () => ({ + useHiddenFieldNames: vi.fn(), +})) + +describe('ShowAllSettings', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.mocked(useHiddenFieldNames).mockReturnValue('default value, placeholder') + }) + + it('should render the summary and hidden field names', () => { + const ShowAllSettingsHarness = () => { + const initialData: InputFieldFormProps['initialData'] = { + type: PipelineInputVarType.textInput, + } + const form = useAppForm({ + defaultValues: initialData, + onSubmit: () => {}, + }) + const ShowAllSettingsComp = ShowAllSettings({ + initialData, + handleShowAllSettings: vi.fn(), + }) + return + } + render() + + expect(useHiddenFieldNames).toHaveBeenCalledWith(PipelineInputVarType.textInput) + expect(screen.getByText('appDebug.variableConfig.showAllSettings')).toBeInTheDocument() + expect(screen.getByText('default value, placeholder')).toBeInTheDocument() + }) + + it('should call the click handler when the row is pressed', () => { + const handleShowAllSettings = vi.fn() + const ShowAllSettingsHarness = () => { + const initialData: InputFieldFormProps['initialData'] = { + type: PipelineInputVarType.textInput, + } + const form = useAppForm({ + defaultValues: initialData, + onSubmit: () => {}, + }) + const ShowAllSettingsComp = ShowAllSettings({ + initialData, + handleShowAllSettings, + }) + return + } + render() + + fireEvent.click(screen.getByText('appDebug.variableConfig.showAllSettings')) + + expect(handleShowAllSettings).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/field-list/__tests__/field-item.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/field-list/__tests__/field-item.spec.tsx new file mode 100644 index 0000000000..4a738761d0 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/field-list/__tests__/field-item.spec.tsx @@ -0,0 +1,83 @@ +import type { InputVar } from '@/models/pipeline' +import { fireEvent, render, screen } from '@testing-library/react' +import { PipelineInputVarType } from '@/models/pipeline' +import FieldItem from '../field-item' + +const createInputVar = (overrides: Partial = {}): InputVar => ({ + type: PipelineInputVarType.textInput, + label: 'Field Label', + variable: 'field_name', + max_length: 48, + default_value: '', + required: true, + tooltips: '', + options: [], + placeholder: '', + unit: '', + allowed_file_upload_methods: [], + allowed_file_types: [], + allowed_file_extensions: [], + ...overrides, +}) + +describe('FieldItem', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the variable, label, and required badge', () => { + render( + , + ) + + expect(screen.getByText('field_name')).toBeInTheDocument() + expect(screen.getByText('Field Label')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.start.required')).toBeInTheDocument() + }) + + it('should show edit and delete controls on hover and trigger both callbacks', () => { + const onClickEdit = vi.fn() + const onRemove = vi.fn() + const { container } = render( + , + ) + + fireEvent.mouseEnter(container.firstChild!) + const buttons = screen.getAllByRole('button') + fireEvent.click(buttons[0]) + fireEvent.click(buttons[1]) + + expect(onClickEdit).toHaveBeenCalledWith('custom_field') + expect(onRemove).toHaveBeenCalledWith(2) + }) + + it('should keep the row readonly when readonly is enabled', () => { + const onClickEdit = vi.fn() + const onRemove = vi.fn() + const { container } = render( + , + ) + + fireEvent.mouseEnter(container.firstChild!) + + expect(screen.queryAllByRole('button')).toHaveLength(0) + expect(onClickEdit).not.toHaveBeenCalled() + expect(onRemove).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/field-list/__tests__/field-list-container.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/field-list/__tests__/field-list-container.spec.tsx new file mode 100644 index 0000000000..5e49a4c9b4 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/field-list/__tests__/field-list-container.spec.tsx @@ -0,0 +1,60 @@ +import type { InputVar } from '@/models/pipeline' +import { fireEvent, render, screen } from '@testing-library/react' +import { PipelineInputVarType } from '@/models/pipeline' +import FieldListContainer from '../field-list-container' + +const createInputVar = (variable: string): InputVar => ({ + type: PipelineInputVarType.textInput, + label: variable, + variable, + max_length: 48, + default_value: '', + required: true, + tooltips: '', + options: [], + placeholder: '', + unit: '', + allowed_file_upload_methods: [], + allowed_file_types: [], + allowed_file_extensions: [], +}) + +describe('FieldListContainer', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the field items inside the sortable container', () => { + const onListSortChange = vi.fn() + const { container } = render( + , + ) + + expect(screen.getAllByText('field_1').length).toBeGreaterThan(0) + expect(screen.getAllByText('field_2').length).toBeGreaterThan(0) + expect(container.querySelector('.handle')).toBeInTheDocument() + expect(onListSortChange).not.toHaveBeenCalled() + }) + + it('should honor readonly mode for the rendered field rows', () => { + const { container } = render( + , + ) + + const firstRow = container.querySelector('.handle') + fireEvent.mouseEnter(firstRow!) + + expect(screen.queryAllByRole('button')).toHaveLength(0) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/__tests__/datasource.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/__tests__/datasource.spec.tsx new file mode 100644 index 0000000000..b0ab5d5312 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/__tests__/datasource.spec.tsx @@ -0,0 +1,24 @@ +import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' +import { render, screen } from '@testing-library/react' +import Datasource from '../datasource' + +vi.mock('@/app/components/workflow/hooks', () => ({ + useToolIcon: () => 'tool-icon', +})) + +vi.mock('@/app/components/workflow/block-icon', () => ({ + default: ({ toolIcon }: { toolIcon: string }) =>
{toolIcon}
, +})) + +describe('Datasource', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the datasource title and icon', () => { + render() + + expect(screen.getByTestId('block-icon')).toHaveTextContent('tool-icon') + expect(screen.getByText('Knowledge Base')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/__tests__/global-inputs.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/__tests__/global-inputs.spec.tsx new file mode 100644 index 0000000000..602a8a4708 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/__tests__/global-inputs.spec.tsx @@ -0,0 +1,23 @@ +import { render, screen } from '@testing-library/react' +import GlobalInputs from '../global-inputs' + +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ + popupContent, + }: { + popupContent: React.ReactNode + }) =>
{popupContent}
, +})) + +describe('GlobalInputs', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the title and tooltip copy', () => { + render() + + expect(screen.getByText('datasetPipeline.inputFieldPanel.globalInputs.title')).toBeInTheDocument() + expect(screen.getByTestId('tooltip')).toHaveTextContent('datasetPipeline.inputFieldPanel.globalInputs.tooltip') + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/data-source.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/data-source.spec.tsx new file mode 100644 index 0000000000..04701aeba4 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/data-source.spec.tsx @@ -0,0 +1,73 @@ +import type { Datasource } from '../../../test-run/types' +import { fireEvent, render, screen } from '@testing-library/react' +import DataSource from '../data-source' + +const { + mockOnSelect, + mockUseDraftPipelinePreProcessingParams, +} = vi.hoisted(() => ({ + mockOnSelect: vi.fn(), + mockUseDraftPipelinePreProcessingParams: vi.fn(() => ({ + data: { + variables: [{ variable: 'source' }], + }, + })), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { pipelineId: string }) => string) => selector({ pipelineId: 'pipeline-1' }), +})) + +vi.mock('@/service/use-pipeline', () => ({ + useDraftPipelinePreProcessingParams: mockUseDraftPipelinePreProcessingParams, +})) + +vi.mock('../../../test-run/preparation/data-source-options', () => ({ + default: ({ + onSelect, + dataSourceNodeId, + }: { + onSelect: (data: Datasource) => void + dataSourceNodeId: string + }) => ( +
+ +
+ ), +})) + +vi.mock('../form', () => ({ + default: ({ variables }: { variables: Array<{ variable: string }> }) => ( +
{variables.map(item => item.variable).join(',')}
+ ), +})) + +describe('DataSource preview', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the datasource selection step and forward selected values', () => { + render( + , + ) + + fireEvent.click(screen.getByText('select datasource')) + + expect(screen.getByText('datasetPipeline.inputFieldPanel.preview.stepOneTitle')).toBeInTheDocument() + expect(screen.getByTestId('data-source-options')).toHaveAttribute('data-node-id', 'node-1') + expect(screen.getByTestId('preview-form')).toHaveTextContent('source') + expect(mockUseDraftPipelinePreProcessingParams).toHaveBeenCalledWith({ + pipeline_id: 'pipeline-1', + node_id: 'node-1', + }, true) + expect(mockOnSelect).toHaveBeenCalledWith({ nodeId: 'source-node' }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/form.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/form.spec.tsx new file mode 100644 index 0000000000..66299e112f --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/form.spec.tsx @@ -0,0 +1,64 @@ +import type { RAGPipelineVariables } from '@/models/pipeline' +import { render, screen } from '@testing-library/react' +import Form from '../form' + +type MockForm = { + id: string +} + +const { + mockForm, + mockBaseField, + mockUseInitialData, + mockUseConfigurations, +} = vi.hoisted(() => ({ + mockForm: { + id: 'form-1', + } as MockForm, + mockBaseField: vi.fn(({ config }: { config: { variable: string } }) => { + return function FieldComponent() { + return
{config.variable}
+ } + }), + mockUseInitialData: vi.fn(() => ({ source: 'node-1' })), + mockUseConfigurations: vi.fn(() => [{ variable: 'source' }, { variable: 'chunkSize' }]), +})) + +vi.mock('@/app/components/base/form', () => ({ + useAppForm: () => mockForm, +})) + +vi.mock('@/app/components/base/form/form-scenarios/base/field', () => ({ + default: mockBaseField, +})) + +vi.mock('@/app/components/rag-pipeline/hooks/use-input-fields', () => ({ + useInitialData: mockUseInitialData, + useConfigurations: mockUseConfigurations, +})) + +describe('Preview form', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should build fields from the pipeline variable configuration', () => { + render(
) + + expect(mockUseInitialData).toHaveBeenCalled() + expect(mockUseConfigurations).toHaveBeenCalled() + expect(screen.getAllByTestId('base-field')).toHaveLength(2) + expect(screen.getByText('source')).toBeInTheDocument() + expect(screen.getByText('chunkSize')).toBeInTheDocument() + }) + + it('should prevent the native form submission', () => { + const { container } = render() + const form = container.querySelector('form')! + const submitEvent = new Event('submit', { bubbles: true, cancelable: true }) + + form.dispatchEvent(submitEvent) + + expect(submitEvent.defaultPrevented).toBe(true) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/process-documents.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/process-documents.spec.tsx new file mode 100644 index 0000000000..3e4944d775 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/preview/__tests__/process-documents.spec.tsx @@ -0,0 +1,39 @@ +import { render, screen } from '@testing-library/react' +import ProcessDocuments from '../process-documents' + +const mockUseDraftPipelineProcessingParams = vi.hoisted(() => vi.fn(() => ({ + data: { + variables: [{ variable: 'chunkSize' }], + }, +}))) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { pipelineId: string }) => string) => selector({ pipelineId: 'pipeline-1' }), +})) + +vi.mock('@/service/use-pipeline', () => ({ + useDraftPipelineProcessingParams: mockUseDraftPipelineProcessingParams, +})) + +vi.mock('../form', () => ({ + default: ({ variables }: { variables: Array<{ variable: string }> }) => ( +
{variables.map(item => item.variable).join(',')}
+ ), +})) + +describe('ProcessDocuments preview', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the processing step and its variables', () => { + render() + + expect(screen.getByText('datasetPipeline.inputFieldPanel.preview.stepTwoTitle')).toBeInTheDocument() + expect(screen.getByTestId('preview-form')).toHaveTextContent('chunkSize') + expect(mockUseDraftPipelineProcessingParams).toHaveBeenCalledWith({ + pipeline_id: 'pipeline-1', + node_id: 'node-2', + }, true) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/__tests__/header.spec.tsx b/web/app/components/rag-pipeline/components/panel/test-run/__tests__/header.spec.tsx new file mode 100644 index 0000000000..8149bac144 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/__tests__/header.spec.tsx @@ -0,0 +1,60 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import Header from '../header' + +const { + mockSetIsPreparingDataSource, + mockHandleCancelDebugAndPreviewPanel, + mockWorkflowStore, +} = vi.hoisted(() => ({ + mockSetIsPreparingDataSource: vi.fn(), + mockHandleCancelDebugAndPreviewPanel: vi.fn(), + mockWorkflowStore: { + getState: vi.fn(() => ({ + isPreparingDataSource: true, + setIsPreparingDataSource: vi.fn(), + })), + }, +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => mockWorkflowStore, +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowInteractions: () => ({ + handleCancelDebugAndPreviewPanel: mockHandleCancelDebugAndPreviewPanel, + }), +})) + +describe('TestRun header', () => { + beforeEach(() => { + vi.clearAllMocks() + mockWorkflowStore.getState.mockReturnValue({ + isPreparingDataSource: true, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + }) + }) + + it('should render the title and reset preparing state on close', () => { + render(
) + + fireEvent.click(screen.getByRole('button')) + + expect(screen.getByText('datasetPipeline.testRun.title')).toBeInTheDocument() + expect(mockSetIsPreparingDataSource).toHaveBeenCalledWith(false) + expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalledTimes(1) + }) + + it('should only cancel the panel when the datasource preparation flag is false', () => { + mockWorkflowStore.getState.mockReturnValue({ + isPreparingDataSource: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + }) + + render(
) + fireEvent.click(screen.getByRole('button')) + + expect(mockSetIsPreparingDataSource).not.toHaveBeenCalled() + expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/__tests__/footer-tips.spec.tsx b/web/app/components/rag-pipeline/components/panel/test-run/preparation/__tests__/footer-tips.spec.tsx new file mode 100644 index 0000000000..b4eab3fe72 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/__tests__/footer-tips.spec.tsx @@ -0,0 +1,14 @@ +import { render, screen } from '@testing-library/react' +import FooterTips from '../footer-tips' + +describe('FooterTips', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the localized footer copy', () => { + render() + + expect(screen.getByText('datasetPipeline.testRun.tooltip')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/__tests__/step-indicator.spec.tsx b/web/app/components/rag-pipeline/components/panel/test-run/preparation/__tests__/step-indicator.spec.tsx new file mode 100644 index 0000000000..d5985f2969 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/__tests__/step-indicator.spec.tsx @@ -0,0 +1,41 @@ +import { render, screen } from '@testing-library/react' +import StepIndicator from '../step-indicator' + +describe('StepIndicator', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render all step labels and highlight the current step', () => { + const { container } = render( + , + ) + + expect(screen.getByText('Select source')).toBeInTheDocument() + expect(screen.getByText('Process docs')).toBeInTheDocument() + expect(screen.getByText('Run test')).toBeInTheDocument() + expect(container.querySelector('.bg-state-accent-solid')).toBeInTheDocument() + expect(screen.getByText('Process docs').parentElement).toHaveClass('text-state-accent-solid') + }) + + it('should keep inactive steps in the tertiary state', () => { + render( + , + ) + + expect(screen.getByText('Process docs').parentElement).toHaveClass('text-text-tertiary') + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/data-source-options/__tests__/option-card.spec.tsx b/web/app/components/rag-pipeline/components/panel/test-run/preparation/data-source-options/__tests__/option-card.spec.tsx new file mode 100644 index 0000000000..83cb252943 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/data-source-options/__tests__/option-card.spec.tsx @@ -0,0 +1,49 @@ +import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' +import { fireEvent, render, screen } from '@testing-library/react' +import OptionCard from '../option-card' + +vi.mock('@/app/components/workflow/hooks', () => ({ + useToolIcon: () => 'source-icon', +})) + +vi.mock('@/app/components/workflow/block-icon', () => ({ + default: ({ toolIcon }: { toolIcon: string }) =>
{toolIcon}
, +})) + +describe('OptionCard', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render the datasource label and icon', () => { + render( + , + ) + + expect(screen.getByTestId('block-icon')).toHaveTextContent('source-icon') + expect(screen.getByText('Website Crawl')).toBeInTheDocument() + }) + + it('should call onClick with the card value and apply selected styles', () => { + const onClick = vi.fn() + render( + , + ) + + fireEvent.click(screen.getByText('Online Drive')) + + expect(onClick).toHaveBeenCalledWith('online-drive') + expect(screen.getByText('Online Drive')).toHaveClass('text-text-primary') + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/actions.spec.tsx b/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/actions.spec.tsx new file mode 100644 index 0000000000..69f576eae7 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/actions.spec.tsx @@ -0,0 +1,67 @@ +import type { CustomActionsProps } from '@/app/components/base/form/components/form/actions' +import { fireEvent, render, screen } from '@testing-library/react' +import { WorkflowRunningStatus } from '@/app/components/workflow/types' +import Actions from '../actions' + +let mockWorkflowRunningData: { result: { status: WorkflowRunningStatus } } | undefined + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { workflowRunningData: typeof mockWorkflowRunningData }) => unknown) => selector({ + workflowRunningData: mockWorkflowRunningData, + }), +})) + +const createFormParams = (overrides: Partial = {}): CustomActionsProps => ({ + form: { + handleSubmit: vi.fn(), + } as unknown as CustomActionsProps['form'], + isSubmitting: false, + canSubmit: true, + ...overrides, +}) + +describe('Document processing actions', () => { + beforeEach(() => { + vi.clearAllMocks() + mockWorkflowRunningData = undefined + }) + + it('should render back/process actions and trigger both callbacks', () => { + const onBack = vi.fn() + const formParams = createFormParams() + + render() + + fireEvent.click(screen.getByRole('button', { name: 'datasetPipeline.operations.backToDataSource' })) + fireEvent.click(screen.getByRole('button', { name: 'datasetPipeline.operations.process' })) + + expect(onBack).toHaveBeenCalledTimes(1) + expect(formParams.form.handleSubmit).toHaveBeenCalledTimes(1) + }) + + it('should disable processing when runDisabled or the workflow is already running', () => { + const { rerender } = render( + , + ) + + expect(screen.getByRole('button', { name: 'datasetPipeline.operations.process' })).toBeDisabled() + + mockWorkflowRunningData = { + result: { + status: WorkflowRunningStatus.Running, + }, + } + rerender( + , + ) + + expect(screen.getByRole('button', { name: /datasetPipeline\.operations\.process/i })).toBeDisabled() + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/hooks.spec.ts b/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/hooks.spec.ts new file mode 100644 index 0000000000..822d553732 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/hooks.spec.ts @@ -0,0 +1,32 @@ +import { renderHook } from '@testing-library/react' +import { useInputVariables } from '../hooks' + +const mockUseDraftPipelineProcessingParams = vi.hoisted(() => vi.fn(() => ({ + data: { variables: [{ variable: 'chunkSize' }] }, + isFetching: true, +}))) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { pipelineId: string }) => string) => selector({ pipelineId: 'pipeline-1' }), +})) + +vi.mock('@/service/use-pipeline', () => ({ + useDraftPipelineProcessingParams: mockUseDraftPipelineProcessingParams, +})) + +describe('useInputVariables', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should query processing params with the current pipeline id and datasource node id', () => { + const { result } = renderHook(() => useInputVariables('datasource-node')) + + expect(mockUseDraftPipelineProcessingParams).toHaveBeenCalledWith({ + pipeline_id: 'pipeline-1', + node_id: 'datasource-node', + }) + expect(result.current.isFetchingParams).toBe(true) + expect(result.current.paramsConfig).toEqual({ variables: [{ variable: 'chunkSize' }] }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/options.spec.tsx b/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/options.spec.tsx new file mode 100644 index 0000000000..fcfa305bb3 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/__tests__/options.spec.tsx @@ -0,0 +1,140 @@ +import type { ZodSchema } from 'zod' +import type { CustomActionsProps } from '@/app/components/base/form/components/form/actions' +import type { BaseConfiguration } from '@/app/components/base/form/form-scenarios/base/types' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import Options from '../options' + +const { + mockFormValue, + mockHandleSubmit, + mockToastError, + mockBaseField, +} = vi.hoisted(() => ({ + mockFormValue: { chunkSize: 256 } as Record, + mockHandleSubmit: vi.fn(), + mockToastError: vi.fn(), + mockBaseField: vi.fn(({ config }: { config: { variable: string } }) => { + return function FieldComponent() { + return
{config.variable}
+ } + }), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + }, +})) + +vi.mock('@/app/components/base/form/form-scenarios/base/field', () => ({ + default: mockBaseField, +})) + +vi.mock('@/app/components/base/form', () => ({ + useAppForm: ({ + onSubmit, + validators, + }: { + onSubmit: (params: { value: Record }) => void + validators?: { + onSubmit?: (params: { value: Record }) => string | undefined + } + }) => ({ + handleSubmit: () => { + const validationResult = validators?.onSubmit?.({ value: mockFormValue }) + if (!validationResult) + onSubmit({ value: mockFormValue }) + mockHandleSubmit() + }, + AppForm: ({ children }: { children: React.ReactNode }) =>
{children}
, + Actions: ({ CustomActions }: { CustomActions: (props: CustomActionsProps) => React.ReactNode }) => ( +
+ {CustomActions({ + form: { + handleSubmit: mockHandleSubmit, + } as unknown as CustomActionsProps['form'], + isSubmitting: false, + canSubmit: true, + })} +
+ ), + }), +})) + +const createSchema = (success: boolean): ZodSchema => ({ + safeParse: vi.fn(() => { + if (success) + return { success: true } + + return { + success: false, + error: { + issues: [{ + path: ['chunkSize'], + message: 'Invalid value', + }], + }, + } + }), +}) as unknown as ZodSchema + +describe('Document processing options', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render base fields and the custom actions slot', () => { + render( +
custom actions
} + onSubmit={vi.fn()} + />, + ) + + expect(screen.getByTestId('base-field')).toHaveTextContent('chunkSize') + expect(screen.getByTestId('form-actions')).toBeInTheDocument() + expect(screen.getByTestId('custom-actions')).toBeInTheDocument() + }) + + it('should validate and toast the first schema error before submitting', async () => { + const onSubmit = vi.fn() + const { container } = render( +
actions
} + onSubmit={onSubmit} + />, + ) + + fireEvent.submit(container.querySelector('form')!) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('Path: chunkSize Error: Invalid value') + }) + expect(onSubmit).not.toHaveBeenCalled() + }) + + it('should submit the parsed form value when validation succeeds', async () => { + const onSubmit = vi.fn() + const { container } = render( +
actions
} + onSubmit={onSubmit} + />, + ) + + fireEvent.submit(container.querySelector('form')!) + + await waitFor(() => { + expect(onSubmit).toHaveBeenCalledWith(mockFormValue) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/result/result-preview/__tests__/utils.spec.ts b/web/app/components/rag-pipeline/components/panel/test-run/result/result-preview/__tests__/utils.spec.ts new file mode 100644 index 0000000000..376b529d40 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/result/result-preview/__tests__/utils.spec.ts @@ -0,0 +1,84 @@ +import { ChunkingMode } from '@/models/datasets' +import { formatPreviewChunks } from '../utils' + +vi.mock('@/config', () => ({ + RAG_PIPELINE_PREVIEW_CHUNK_NUM: 2, +})) + +describe('result preview utils', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should return undefined for empty outputs', () => { + expect(formatPreviewChunks(undefined)).toBeUndefined() + expect(formatPreviewChunks(null)).toBeUndefined() + }) + + it('should format text chunks and limit them to the preview length', () => { + const result = formatPreviewChunks({ + chunk_structure: ChunkingMode.text, + preview: [ + { content: 'Chunk 1', summary: 'S1' }, + { content: 'Chunk 2', summary: 'S2' }, + { content: 'Chunk 3', summary: 'S3' }, + ], + }) + + expect(result).toEqual([ + { content: 'Chunk 1', summary: 'S1' }, + { content: 'Chunk 2', summary: 'S2' }, + ]) + }) + + it('should format paragraph and full-doc parent-child previews differently', () => { + const paragraph = formatPreviewChunks({ + chunk_structure: ChunkingMode.parentChild, + parent_mode: 'paragraph', + preview: [ + { content: 'Parent 1', child_chunks: ['c1', 'c2', 'c3'] }, + { content: 'Parent 2', child_chunks: ['c4'] }, + { content: 'Parent 3', child_chunks: ['c5'] }, + ], + }) + const fullDoc = formatPreviewChunks({ + chunk_structure: ChunkingMode.parentChild, + parent_mode: 'full-doc', + preview: [ + { content: 'Parent 1', child_chunks: ['c1', 'c2', 'c3'] }, + ], + }) + + expect(paragraph).toEqual({ + parent_mode: 'paragraph', + parent_child_chunks: [ + { parent_content: 'Parent 1', parent_summary: undefined, child_contents: ['c1', 'c2', 'c3'], parent_mode: 'paragraph' }, + { parent_content: 'Parent 2', parent_summary: undefined, child_contents: ['c4'], parent_mode: 'paragraph' }, + ], + }) + expect(fullDoc).toEqual({ + parent_mode: 'full-doc', + parent_child_chunks: [ + { parent_content: 'Parent 1', child_contents: ['c1', 'c2'], parent_mode: 'full-doc' }, + ], + }) + }) + + it('should format qa previews and limit them to the preview size', () => { + const result = formatPreviewChunks({ + chunk_structure: ChunkingMode.qa, + qa_preview: [ + { question: 'Q1', answer: 'A1' }, + { question: 'Q2', answer: 'A2' }, + { question: 'Q3', answer: 'A3' }, + ], + }) + + expect(result).toEqual({ + qa_chunks: [ + { question: 'Q1', answer: 'A1' }, + { question: 'Q2', answer: 'A2' }, + ], + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/result/tabs/__tests__/tab.spec.tsx b/web/app/components/rag-pipeline/components/panel/test-run/result/tabs/__tests__/tab.spec.tsx new file mode 100644 index 0000000000..0597bc3de8 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/test-run/result/tabs/__tests__/tab.spec.tsx @@ -0,0 +1,64 @@ +import type { WorkflowRunningData } from '@/app/components/workflow/types' +import { fireEvent, render, screen } from '@testing-library/react' +import Tab from '../tab' + +const createWorkflowRunningData = (): WorkflowRunningData => ({ + task_id: 'task-1', + message_id: 'message-1', + conversation_id: 'conversation-1', + result: { + workflow_id: 'workflow-1', + inputs: '{}', + inputs_truncated: false, + process_data: '{}', + process_data_truncated: false, + outputs: '{}', + outputs_truncated: false, + status: 'succeeded', + elapsed_time: 10, + total_tokens: 20, + created_at: Date.now(), + finished_at: Date.now(), + steps: 1, + total_steps: 1, + }, + tracing: [], +}) + +describe('Tab', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render an active tab and pass its value on click', () => { + const onClick = vi.fn() + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'Preview' })) + + expect(screen.getByRole('button')).toHaveClass('border-util-colors-blue-brand-blue-brand-600') + expect(onClick).toHaveBeenCalledWith('preview') + }) + + it('should disable the tab when workflow run data is unavailable', () => { + render( + , + ) + + expect(screen.getByRole('button', { name: 'Trace' })).toBeDisabled() + expect(screen.getByRole('button', { name: 'Trace' })).toHaveClass('opacity-30') + }) +}) diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/input-field-button.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/input-field-button.spec.tsx new file mode 100644 index 0000000000..493f3c3014 --- /dev/null +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/input-field-button.spec.tsx @@ -0,0 +1,35 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import InputFieldButton from '../input-field-button' + +const { + mockSetShowInputFieldPanel, + mockSetShowEnvPanel, +} = vi.hoisted(() => ({ + mockSetShowInputFieldPanel: vi.fn(), + mockSetShowEnvPanel: vi.fn(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { + setShowInputFieldPanel: typeof mockSetShowInputFieldPanel + setShowEnvPanel: typeof mockSetShowEnvPanel + }) => unknown) => selector({ + setShowInputFieldPanel: mockSetShowInputFieldPanel, + setShowEnvPanel: mockSetShowEnvPanel, + }), +})) + +describe('InputFieldButton', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should open the input field panel and close the env panel', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'datasetPipeline.inputField' })) + + expect(mockSetShowInputFieldPanel).toHaveBeenCalledWith(true) + expect(mockSetShowEnvPanel).toHaveBeenCalledWith(false) + }) +}) diff --git a/web/app/components/rag-pipeline/utils/__tests__/nodes.spec.ts b/web/app/components/rag-pipeline/utils/__tests__/nodes.spec.ts new file mode 100644 index 0000000000..c90e702d8e --- /dev/null +++ b/web/app/components/rag-pipeline/utils/__tests__/nodes.spec.ts @@ -0,0 +1,92 @@ +import type { Viewport } from 'reactflow' +import type { Node } from '@/app/components/workflow/types' +import { BlockEnum } from '@/app/components/workflow/types' +import { processNodesWithoutDataSource } from '../nodes' + +vi.mock('@/app/components/workflow/constants', () => ({ + CUSTOM_NODE: 'custom', + NODE_WIDTH_X_OFFSET: 400, + START_INITIAL_POSITION: { x: 100, y: 100 }, +})) + +vi.mock('@/app/components/workflow/nodes/data-source-empty/constants', () => ({ + CUSTOM_DATA_SOURCE_EMPTY_NODE: 'data-source-empty', +})) + +vi.mock('@/app/components/workflow/note-node/constants', () => ({ + CUSTOM_NOTE_NODE: 'note', +})) + +vi.mock('@/app/components/workflow/note-node/types', () => ({ + NoteTheme: { blue: 'blue' }, +})) + +vi.mock('@/app/components/workflow/utils', () => ({ + generateNewNode: ({ id, type, data, position }: { id: string, type: string, data: object, position: { x: number, y: number } }) => ({ + newNode: { id, type, data, position }, + }), +})) + +describe('processNodesWithoutDataSource', () => { + it('should return the original nodes when a datasource node already exists', () => { + const nodes = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.DataSource }, + position: { x: 100, y: 100 }, + }, + ] as Node[] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + expect(result.nodes).toBe(nodes) + expect(result.viewport).toBe(viewport) + }) + + it('should prepend datasource empty and note nodes when the pipeline starts without a datasource', () => { + const nodes = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase }, + position: { x: 300, y: 200 }, + }, + ] as Node[] + + const result = processNodesWithoutDataSource(nodes, { x: 0, y: 0, zoom: 2 }) + + expect(result.nodes[0]).toEqual(expect.objectContaining({ + id: 'data-source-empty', + type: 'data-source-empty', + position: { x: -100, y: 200 }, + })) + expect(result.nodes[1]).toEqual(expect.objectContaining({ + id: 'note', + type: 'note', + position: { x: -100, y: 300 }, + })) + expect(result.viewport).toEqual({ + x: 400, + y: -200, + zoom: 2, + }) + }) + + it('should leave nodes unchanged when there is no custom node to anchor from', () => { + const nodes = [ + { + id: 'node-1', + type: 'note', + data: { type: BlockEnum.Answer }, + position: { x: 100, y: 100 }, + }, + ] as Node[] + + const result = processNodesWithoutDataSource(nodes) + + expect(result.nodes).toBe(nodes) + expect(result.viewport).toBeUndefined() + }) +}) diff --git a/web/app/components/tools/edit-custom-collection-modal/__tests__/examples.spec.ts b/web/app/components/tools/edit-custom-collection-modal/__tests__/examples.spec.ts new file mode 100644 index 0000000000..6fe3576c26 --- /dev/null +++ b/web/app/components/tools/edit-custom-collection-modal/__tests__/examples.spec.ts @@ -0,0 +1,18 @@ +import { describe, expect, it } from 'vitest' +import examples from '../examples' + +describe('edit-custom-collection examples', () => { + it('provides json, yaml, and blank templates in fixed order', () => { + expect(examples.map(example => example.key)).toEqual([ + 'json', + 'yaml', + 'blankTemplate', + ]) + }) + + it('contains representative OpenAPI content for each template', () => { + expect(examples[0].content).toContain('"openapi": "3.1.0"') + expect(examples[1].content).toContain('openapi: "3.0.0"') + expect(examples[2].content).toContain('"title": "Untitled"') + }) +}) diff --git a/web/app/components/tools/labels/__tests__/constant.spec.ts b/web/app/components/tools/labels/__tests__/constant.spec.ts new file mode 100644 index 0000000000..614476fb8c --- /dev/null +++ b/web/app/components/tools/labels/__tests__/constant.spec.ts @@ -0,0 +1,33 @@ +import type { Label } from '../constant' +import { describe, expect, it } from 'vitest' + +describe('tool label type contract', () => { + it('accepts string labels', () => { + const label: Label = { + name: 'agent', + label: 'Agent', + icon: 'robot', + } + + expect(label).toEqual({ + name: 'agent', + label: 'Agent', + icon: 'robot', + }) + }) + + it('accepts i18n labels', () => { + const label: Label = { + name: 'workflow', + label: { + en_US: 'Workflow', + zh_Hans: '工作流', + }, + } + + expect(label.label).toEqual({ + en_US: 'Workflow', + zh_Hans: '工作流', + }) + }) +}) diff --git a/web/app/components/tools/workflow-tool/__tests__/helpers.spec.ts b/web/app/components/tools/workflow-tool/__tests__/helpers.spec.ts new file mode 100644 index 0000000000..acf8aafdf8 --- /dev/null +++ b/web/app/components/tools/workflow-tool/__tests__/helpers.spec.ts @@ -0,0 +1,102 @@ +import type { TFunction } from 'i18next' +import { describe, expect, it } from 'vitest' +import { VarType } from '@/app/components/workflow/types' +import { + buildWorkflowToolRequestPayload, + getReservedWorkflowOutputParameters, + getWorkflowOutputParameters, + hasReservedWorkflowOutputConflict, + isWorkflowToolNameValid, + RESERVED_WORKFLOW_OUTPUTS, +} from '../helpers' + +describe('workflow-tool helpers', () => { + it('validates workflow tool names', () => { + expect(isWorkflowToolNameValid('')).toBe(true) + expect(isWorkflowToolNameValid('workflow_tool_1')).toBe(true) + expect(isWorkflowToolNameValid('workflow-tool')).toBe(false) + expect(isWorkflowToolNameValid('workflow tool')).toBe(false) + }) + + it('builds translated reserved workflow outputs', () => { + const t = ((key: string, options?: { ns?: string }) => `${options?.ns}:${key}`) as TFunction + + expect(getReservedWorkflowOutputParameters(t)).toEqual([ + { + ...RESERVED_WORKFLOW_OUTPUTS[0], + description: 'workflow:nodes.tool.outputVars.text', + }, + { + ...RESERVED_WORKFLOW_OUTPUTS[1], + description: 'workflow:nodes.tool.outputVars.files.title', + }, + { + ...RESERVED_WORKFLOW_OUTPUTS[2], + description: 'workflow:nodes.tool.outputVars.json', + }, + ]) + }) + + it('detects reserved output conflicts', () => { + expect(hasReservedWorkflowOutputConflict(RESERVED_WORKFLOW_OUTPUTS, 'text')).toBe(true) + expect(hasReservedWorkflowOutputConflict(RESERVED_WORKFLOW_OUTPUTS, 'custom')).toBe(false) + }) + + it('derives workflow output parameters from schema through helper wrapper', () => { + expect(getWorkflowOutputParameters([], { + type: 'object', + properties: { + text: { + type: VarType.string, + description: 'Result text', + }, + }, + })).toEqual([ + { + name: 'text', + description: 'Result text', + type: VarType.string, + }, + ]) + }) + + it('builds workflow tool request payload', () => { + expect(buildWorkflowToolRequestPayload({ + name: 'workflow_tool', + description: 'Workflow tool', + emoji: { + content: '🧠', + background: '#ffffff', + }, + label: 'Workflow Tool', + labels: ['agent', 'workflow'], + parameters: [ + { + name: 'question', + type: VarType.string, + required: true, + form: 'llm', + description: 'Question to ask', + }, + ], + privacyPolicy: 'https://example.com/privacy', + })).toEqual({ + name: 'workflow_tool', + description: 'Workflow tool', + icon: { + content: '🧠', + background: '#ffffff', + }, + label: 'Workflow Tool', + labels: ['agent', 'workflow'], + parameters: [ + { + name: 'question', + description: 'Question to ask', + form: 'llm', + }, + ], + privacy_policy: 'https://example.com/privacy', + }) + }) +}) diff --git a/web/app/components/tools/workflow-tool/__tests__/index.spec.tsx b/web/app/components/tools/workflow-tool/__tests__/index.spec.tsx new file mode 100644 index 0000000000..f3f229abea --- /dev/null +++ b/web/app/components/tools/workflow-tool/__tests__/index.spec.tsx @@ -0,0 +1,200 @@ +import type { WorkflowToolModalPayload } from '../index' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WorkflowToolAsModal from '../index' + +vi.mock('@/app/components/base/drawer-plus', () => ({ + default: ({ isShow, onHide, title, body }: { isShow: boolean, onHide: () => void, title: string, body: React.ReactNode }) => ( + isShow + ? ( +
+ {title} + + {body} +
+ ) + : null + ), +})) + +vi.mock('@/app/components/base/emoji-picker', () => ({ + default: ({ onSelect, onClose }: { onSelect: (icon: string, background: string) => void, onClose: () => void }) => ( +
+ + +
+ ), +})) + +vi.mock('@/app/components/base/app-icon', () => ({ + default: ({ onClick, icon }: { onClick?: () => void, icon: string }) => ( + + ), +})) + +vi.mock('@/app/components/tools/labels/selector', () => ({ + default: ({ value, onChange }: { value: string[], onChange: (labels: string[]) => void }) => ( +
+ {value.join(',')} + +
+ ), +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ + children, + popupContent, + }: { + children?: React.ReactNode + popupContent?: React.ReactNode + }) => ( +
+ {children} + {popupContent} +
+ ), +})) + +vi.mock('../confirm-modal', () => ({ + default: ({ show, onClose, onConfirm }: { show: boolean, onClose: () => void, onConfirm: () => void }) => ( + show + ? ( +
+ + +
+ ) + : null + ), +})) + +const mockToastNotify = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + }, +})) + +vi.mock('@/app/components/plugins/hooks', () => ({ + useTags: () => ({ + tags: [ + { name: 'label1', label: 'Label 1' }, + { name: 'label2', label: 'Label 2' }, + ], + }), +})) + +const createPayload = (overrides: Partial = {}): WorkflowToolModalPayload => ({ + icon: { content: '🔧', background: '#ffffff' }, + label: 'My Tool', + name: 'my_tool', + description: 'Tool description', + parameters: [ + { name: 'param1', description: 'Parameter 1', form: 'llm', required: true, type: 'string' }, + ], + outputParameters: [ + { name: 'output1', description: 'Output 1' }, + { name: 'text', description: 'Reserved output duplicate' }, + ], + labels: ['label1'], + privacy_policy: '', + workflow_app_id: 'workflow-app-1', + workflow_tool_id: 'workflow-tool-1', + ...overrides, +}) + +describe('WorkflowToolAsModal', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should create workflow tools with edited form values', async () => { + const user = userEvent.setup() + const onCreate = vi.fn() + + render( + , + ) + + await user.clear(screen.getByPlaceholderText('tools.createTool.toolNamePlaceHolder')) + await user.type(screen.getByPlaceholderText('tools.createTool.toolNamePlaceHolder'), 'Created Tool') + await user.click(screen.getByTestId('append-label')) + await user.click(screen.getByTestId('app-icon')) + await user.click(screen.getByTestId('select-emoji')) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + expect(onCreate).toHaveBeenCalledWith(expect.objectContaining({ + workflow_app_id: 'workflow-app-1', + label: 'Created Tool', + icon: { content: '🚀', background: '#000000' }, + labels: ['label1', 'new-label'], + })) + }) + + it('should block invalid tool-call names before saving', async () => { + const user = userEvent.setup() + const onCreate = vi.fn() + + render( + , + ) + + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + expect(onCreate).not.toHaveBeenCalled() + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + + it('should require confirmation before saving existing workflow tools', async () => { + const user = userEvent.setup() + const onSave = vi.fn() + + render( + , + ) + + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + expect(screen.getByTestId('confirm-modal')).toBeInTheDocument() + + await user.click(screen.getByTestId('confirm-save')) + + await waitFor(() => { + expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ + workflow_tool_id: 'workflow-tool-1', + name: 'my_tool', + })) + }) + }) + + it('should show duplicate reserved output warnings', () => { + render( + , + ) + + expect(screen.getAllByText('tools.createTool.toolOutput.reservedParameterDuplicateTip').length).toBeGreaterThan(0) + }) +}) diff --git a/web/app/components/tools/workflow-tool/helpers.ts b/web/app/components/tools/workflow-tool/helpers.ts new file mode 100644 index 0000000000..9af1107c80 --- /dev/null +++ b/web/app/components/tools/workflow-tool/helpers.ts @@ -0,0 +1,95 @@ +import type { TFunction } from 'i18next' +import type { + Emoji, + WorkflowToolProviderOutputParameter, + WorkflowToolProviderOutputSchema, + WorkflowToolProviderParameter, + WorkflowToolProviderRequest, +} from '../types' +import { VarType } from '@/app/components/workflow/types' +import { buildWorkflowOutputParameters } from './utils' + +export const RESERVED_WORKFLOW_OUTPUTS: WorkflowToolProviderOutputParameter[] = [ + { + name: 'text', + description: '', + type: VarType.string, + reserved: true, + }, + { + name: 'files', + description: '', + type: VarType.arrayFile, + reserved: true, + }, + { + name: 'json', + description: '', + type: VarType.arrayObject, + reserved: true, + }, +] + +export const isWorkflowToolNameValid = (name: string) => { + if (name === '') + return true + + return /^\w+$/.test(name) +} + +export const getReservedWorkflowOutputParameters = (t: TFunction) => { + return RESERVED_WORKFLOW_OUTPUTS.map(output => ({ + ...output, + description: output.name === 'text' + ? t('nodes.tool.outputVars.text', { ns: 'workflow' }) + : output.name === 'files' + ? t('nodes.tool.outputVars.files.title', { ns: 'workflow' }) + : t('nodes.tool.outputVars.json', { ns: 'workflow' }), + })) +} + +export const hasReservedWorkflowOutputConflict = ( + reservedOutputParameters: WorkflowToolProviderOutputParameter[], + name: string, +) => { + return reservedOutputParameters.some(parameter => parameter.name === name) +} + +export const getWorkflowOutputParameters = ( + rawOutputParameters: WorkflowToolProviderOutputParameter[], + outputSchema?: WorkflowToolProviderOutputSchema, +) => { + return buildWorkflowOutputParameters(rawOutputParameters, outputSchema) +} + +export const buildWorkflowToolRequestPayload = ({ + description, + emoji, + label, + labels, + name, + parameters, + privacyPolicy, +}: { + description: string + emoji: Emoji + label: string + labels: string[] + name: string + parameters: WorkflowToolProviderParameter[] + privacyPolicy: string +}): WorkflowToolProviderRequest & { label: string } => { + return { + name, + description, + icon: emoji, + label, + parameters: parameters.map(item => ({ + name: item.name, + description: item.description, + form: item.form, + })), + labels, + privacy_policy: privacyPolicy, + } +} diff --git a/web/app/components/tools/workflow-tool/index.tsx b/web/app/components/tools/workflow-tool/index.tsx index 23329f6a2c..219a0d8f53 100644 --- a/web/app/components/tools/workflow-tool/index.tsx +++ b/web/app/components/tools/workflow-tool/index.tsx @@ -17,9 +17,14 @@ import { toast } from '@/app/components/base/ui/toast' import LabelSelector from '@/app/components/tools/labels/selector' import ConfirmModal from '@/app/components/tools/workflow-tool/confirm-modal' import MethodSelector from '@/app/components/tools/workflow-tool/method-selector' -import { VarType } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' -import { buildWorkflowOutputParameters } from './utils' +import { + buildWorkflowToolRequestPayload, + getReservedWorkflowOutputParameters, + getWorkflowOutputParameters, + hasReservedWorkflowOutputConflict, + isWorkflowToolNameValid, +} from './helpers' export type WorkflowToolModalPayload = { icon: Emoji @@ -67,27 +72,14 @@ const WorkflowToolAsModal: FC = ({ const [parameters, setParameters] = useState(payload.parameters) const rawOutputParameters = payload.outputParameters const outputSchema = payload.tool?.output_schema - const outputParameters = useMemo(() => buildWorkflowOutputParameters(rawOutputParameters, outputSchema), [rawOutputParameters, outputSchema]) - const reservedOutputParameters: WorkflowToolProviderOutputParameter[] = [ - { - name: 'text', - description: t('nodes.tool.outputVars.text', { ns: 'workflow' }), - type: VarType.string, - reserved: true, - }, - { - name: 'files', - description: t('nodes.tool.outputVars.files.title', { ns: 'workflow' }), - type: VarType.arrayFile, - reserved: true, - }, - { - name: 'json', - description: t('nodes.tool.outputVars.json', { ns: 'workflow' }), - type: VarType.arrayObject, - reserved: true, - }, - ] + const outputParameters = useMemo( + () => getWorkflowOutputParameters(rawOutputParameters, outputSchema), + [rawOutputParameters, outputSchema], + ) + const reservedOutputParameters = useMemo( + () => getReservedWorkflowOutputParameters(t), + [t], + ) const handleParameterChange = (key: string, value: string, index: number) => { const newData = produce(parameters, (draft: WorkflowToolProviderParameter[]) => { @@ -105,18 +97,6 @@ const WorkflowToolAsModal: FC = ({ const [privacyPolicy, setPrivacyPolicy] = useState(payload.privacy_policy) const [showModal, setShowModal] = useState(false) - const isNameValid = (name: string) => { - // when the user has not input anything, no need for a warning - if (name === '') - return true - - return /^\w+$/.test(name) - } - - const isOutputParameterReserved = (name: string) => { - return reservedOutputParameters.find(p => p.name === name) - } - const onConfirm = () => { let errorMessage = '' if (!label) @@ -125,7 +105,7 @@ const WorkflowToolAsModal: FC = ({ if (!name) errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t('createTool.nameForToolCall', { ns: 'tools' }) }) - if (!isNameValid(name)) + if (!isWorkflowToolNameValid(name)) errorMessage = t('createTool.nameForToolCall', { ns: 'tools' }) + t('createTool.nameForToolCallTip', { ns: 'tools' }) if (errorMessage) { @@ -133,19 +113,15 @@ const WorkflowToolAsModal: FC = ({ return } - const requestParams = { + const requestParams = buildWorkflowToolRequestPayload({ name, description, - icon: emoji, + emoji, label, - parameters: parameters.map(item => ({ - name: item.name, - description: item.description, - form: item.form, - })), + parameters, labels, - privacy_policy: privacyPolicy, - } + privacyPolicy, + }) if (!isAdd) { onSave?.({ ...requestParams, @@ -175,7 +151,7 @@ const WorkflowToolAsModal: FC = ({
{/* name & icon */}
-
+
{t('createTool.name', { ns: 'tools' })} {' '} * @@ -192,7 +168,7 @@ const WorkflowToolAsModal: FC = ({
{/* name for tool call */}
-
+
{t('createTool.nameForToolCall', { ns: 'tools' })} {' '} * @@ -210,13 +186,13 @@ const WorkflowToolAsModal: FC = ({ value={name} onChange={e => setName(e.target.value)} /> - {!isNameValid(name) && ( + {!isWorkflowToolNameValid(name) && (
{t('createTool.nameForToolCallTip', { ns: 'tools' })}
)}
{/* description */}
-
{t('createTool.description', { ns: 'tools' })}
+
{t('createTool.description', { ns: 'tools' })}