diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index a069b6cbc7..58b4a04d1a 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -7,6 +7,7 @@ ## Summary + ## Screenshots diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml new file mode 100644 index 0000000000..2df364953c --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage-comment.yml @@ -0,0 +1,118 @@ +name: Comment with Pyrefly Type Coverage + +on: + workflow_run: + workflows: + - Pyrefly Type Coverage + types: + - completed + +permissions: {} + +jobs: + comment: + name: Comment PR with type coverage + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + issues: write + pull-requests: write + if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} + steps: + - name: Checkout default branch (trusted code) + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Download type coverage artifact + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const artifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: ${{ github.event.workflow_run.id }}, + }); + const match = artifacts.data.artifacts.find((artifact) => + artifact.name === 'pyrefly_type_coverage' + ); + if (!match) { + throw new Error('pyrefly_type_coverage artifact not found'); + } + const download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: match.id, + archive_format: 'zip', + }); + fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data)); + + - name: Unzip artifact + run: unzip -o pyrefly_type_coverage.zip + + - name: Render coverage markdown from structured data + id: render + run: | + comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \ + --base base_report.json \ + < pr_report.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } > /tmp/type_coverage_comment.md + + - name: Post comment + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + let prNumber = null; + try { + prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10); + } catch (err) { + const prs = context.payload.workflow_run.pull_requests || []; + if (prs.length > 0 && prs[0].number) { + prNumber = prs[0].number; + } + } + if (!prNumber) { + throw new Error('PR number not found in artifact or workflow_run payload'); + } + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const marker = '### Pyrefly Type Coverage'; + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml new file mode 100644 index 0000000000..0c80c6a756 --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage.yml @@ -0,0 +1,120 @@ +name: Pyrefly Type Coverage + +on: + pull_request: + paths: + - 'api/**/*.py' + +permissions: + contents: read + +jobs: + pyrefly-type-coverage: + runs-on: ubuntu-latest + permissions: + contents: read + issues: write + pull-requests: write + steps: + - name: Checkout PR branch + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Run pyrefly report on PR branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \ + mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \ + echo '{}' > /tmp/pyrefly_report_pr.json + + - name: Save helper script from base branch + run: | + git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \ + || cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py + + - name: Checkout base branch + run: git checkout ${{ github.base_ref }} + + - name: Run pyrefly report on base branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \ + mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \ + echo '{}' > /tmp/pyrefly_report_base.json + + - name: Generate coverage comparison + id: coverage + run: | + comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \ + --base /tmp/pyrefly_report_base.json \ + < /tmp/pyrefly_report_pr.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md + + # Save structured data for the fork-PR comment workflow + cp /tmp/pyrefly_report_pr.json pr_report.json + cp /tmp/pyrefly_report_base.json base_report.json + + - name: Save PR number + run: | + echo ${{ github.event.pull_request.number }} > pr_number.txt + + - name: Upload type coverage artifact + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: pyrefly_type_coverage + path: | + pr_report.json + base_report.json + pr_number.txt + + - name: Comment PR with type coverage + if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const marker = '### Pyrefly Type Coverage'; + let body; + try { + body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + } catch { + body = `${marker}\n\n_Coverage report unavailable._`; + } + const prNumber = context.payload.pull_request.number; + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 775401bfa5..d7f007af67 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process. ## Getting Help If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. - -## Automated Agent Contributions - -> [!NOTE] -> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in. diff --git a/api/commands/account.py b/api/commands/account.py index 84af7a5ae6..6a2a2e0428 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,7 +2,6 @@ import base64 import secrets import click -from sqlalchemy.orm import sessionmaker from constants.languages import languages from extensions.ext_database import db @@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) - return + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) + return - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(normalized_email) - click.echo(click.style("Password reset successfully.", fg="green")) + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account = db.session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + db.session.commit() + AccountService.reset_login_error_rate_limit(normalized_email) + click.echo(click.style("Password reset successfully.", fg="green")) @click.command("reset-email", help="Reset the account email.") @@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - email_validate(normalized_new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return + try: + email_validate(normalized_new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return - account.email = normalized_new_email - click.echo(click.style("Email updated successfully.", fg="green")) + account = db.session.merge(account) + account.email = normalized_new_email + db.session.commit() + click.echo(click.style("Email updated successfully.", fg="green")) @click.command("create-tenant", help="Create account and tenant.") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 9e7faa09c5..1fd781b4fc 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,6 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -14,7 +13,6 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models import Account @@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - raise EmailAlreadyInUseError() - else: - account = self._create_new_account(normalized_email, args.password_confirm) - if not account: - raise AccountNotFoundError() - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(normalized_email) + if account: + raise EmailAlreadyInUseError() + else: + account = self._create_new_account(normalized_email, args.password_confirm) + if not account: + raise AccountNotFoundError() + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 63bc98b53f..ed390a5f89 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_reset_password_email( account=account, @@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt, session) - else: - raise AccountNotFound() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AccountNotFound() return {"result": "success"} - def _update_existing_account(self, account, password_hashed, salt, session): + def _update_existing_account(self, account, password_hashed, salt): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5c7011fd22..d31fb4a46c 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,6 @@ import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(user_info.email) return account diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index f3866f6aef..e513e8c8f9 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource): @login_required @account_initialization_required def get(self, external_knowledge_api_id): + _, current_tenant_id = current_account_with_tenant() external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check( - external_knowledge_api_id + external_knowledge_api_id, current_tenant_id ) return {"is_using": external_knowledge_api_is_using, "count": count}, 200 diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index da957bd3e0..b33de88860 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -9,7 +9,6 @@ from flask_restx import Resource, fields, marshal_with from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import supported_language @@ -580,8 +579,7 @@ class ChangeEmailSendEmailApi(Resource): user_email = current_user.email else: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) if account is None: raise AccountNotFound() email_for_sending = account.email diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 80c3289fb4..61fd794c22 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -3,7 +3,6 @@ import secrets from flask import request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( @@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) - token = None + account = AccountService.get_account_by_email_with_case_fallback(request_email) if account is None: raise AuthenticationFailedError() else: @@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt) - else: - raise AuthenticationFailedError() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AuthenticationFailedError() return {"result": "success"} diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0cdbb5f50a..a3fb7b4c5d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from graphon.file import File, FileUploadConfig from graphon.model_runtime.entities.model_entities import AIModelEntity @@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel): extras: dict[str, Any] = Field(default_factory=dict) # tracing instance - trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False) + trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False) class EasyUIBasedAppGenerateEntity(AppGenerateEntity): diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 14d1af2e8b..890f1ca319 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ -from typing import Literal, Optional +from typing import Any, Literal, TypedDict from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter -from core.tools.entities.common_entities import I18nObject +from core.tools.entities.common_entities import I18nObject, I18nObjectDict class DatasourceApiEntity(BaseModel): @@ -17,7 +17,24 @@ class DatasourceApiEntity(BaseModel): output_schema: dict | None = None -ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] +ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None + + +class DatasourceProviderApiEntityDict(TypedDict): + id: str + author: str + name: str + plugin_id: str | None + plugin_unique_identifier: str | None + description: I18nObjectDict + icon: str | dict + label: I18nObjectDict + type: str + team_credentials: dict | None + is_team_authorization: bool + allow_delete: bool + datasources: list[Any] + labels: list[str] class DatasourceProviderApiEntity(BaseModel): @@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel): def convert_none_to_empty_list(cls, v): return v if v is not None else [] - def to_dict(self) -> dict: + def to_dict(self) -> DatasourceProviderApiEntityDict: # ------------- # overwrite datasource parameter types for temp fix datasources = jsonable_encoder(self.datasources) @@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel): parameter["type"] = "files" # ------------- - return { + result: DatasourceProviderApiEntityDict = { "id": self.id, "author": self.author, "name": self.name, @@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel): "datasources": datasources, "labels": self.labels, } + return result diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 04f15dee31..c012e128f4 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer: if not isinstance(message.message, DatasourceMessage.BlobMessage): raise ValueError("unexpected message type") - # FIXME: should do a type check here. - assert isinstance(message.message.blob, bytes) + if not isinstance(message.message.blob, bytes): + raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}") tool_file_manager = ToolFileManager() blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw( user_id=user_id, diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index d8724b8de5..173913196e 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient): logger.exception("Authentication retry failed") raise MCPAuthError(f"Authentication retry failed: {e}") from e - def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any: + def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Execute a function with authentication retry logic. diff --git a/api/core/mcp/entities.py b/api/core/mcp/entities.py index d6d3a677c6..21edc86a57 100644 --- a/api/core/mcp/entities.py +++ b/api/core/mcp/entities.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import StrEnum -from typing import Any, TypeVar +from typing import Any from pydantic import BaseModel @@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) -LifespanContextT = TypeVar("LifespanContextT") - @dataclass -class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]: +class RequestContext[SessionT: BaseSession, LifespanContextT]: request_id: RequestId meta: RequestParams.Meta | None session: SessionT diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 0b3aa79838..70d45b15c4 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul request: ReceiveRequestT _session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]" - _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any] + _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object] def __init__( self, @@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul request_meta: RequestParams.Meta | None, request: ReceiveRequestT, session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]", - on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], + on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object], ): self.request_id = request_id self.request_meta = request_meta diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 2653d20a7d..10e3082aa3 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -31,7 +31,6 @@ ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] RequestId = Annotated[int | str, Field(union_mode="left_to_right")] -type AnyFunction = Callable[..., Any] class RequestParams(BaseModel): diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 7a214777bc..86d042de3e 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -172,10 +172,10 @@ class ModelInstance: function=self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, - prompt_messages=prompt_messages, + prompt_messages=list(prompt_messages), model_parameters=model_parameters, - tools=tools, - stop=stop, + tools=list(tools) if tools else None, + stop=list(stop) if stop else None, stream=stream, callbacks=callbacks, ), @@ -193,15 +193,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") - return cast( - int, - self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model_name, - credentials=self.credentials, - prompt_messages=prompt_messages, - tools=tools, - ), + return self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model_name, + credentials=self.credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, ) def invoke_text_embedding( @@ -216,15 +213,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - EmbeddingResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - texts=texts, - input_type=input_type, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + texts=texts, + input_type=input_type, ) def invoke_multimodal_embedding( @@ -241,15 +235,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - EmbeddingResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - multimodel_documents=multimodel_documents, - input_type=input_type, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + multimodel_documents=multimodel_documents, + input_type=input_type, ) def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: @@ -261,14 +252,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - list[int], - self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model_name, - credentials=self.credentials, - texts=texts, - ), + return self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model_name, + credentials=self.credentials, + texts=texts, ) def invoke_rerank( @@ -289,23 +277,20 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") - return cast( - RerankResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, ) def invoke_multimodal_rerank( self, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -320,17 +305,14 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") - return cast( - RerankResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke_multimodal_rerank, - model=self.model_name, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke_multimodal_rerank, + model=self.model_name, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, ) def invoke_moderation(self, text: str) -> bool: @@ -342,14 +324,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") - return cast( - bool, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - text=text, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + text=text, ) def invoke_speech2text(self, file: IO[bytes]) -> str: @@ -361,14 +340,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") - return cast( - str, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - file=file, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + file=file, ) def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: @@ -381,18 +357,15 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") - return cast( - Iterable[bytes], - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - content_text=content_text, - voice=voice, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + content_text=content_text, + voice=voice, ) - def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): + def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Round-robin invoke :param function: function to invoke @@ -430,9 +403,8 @@ class ModelInstance: continue try: - if "credentials" in kwargs: - del kwargs["credentials"] - return function(*args, **kwargs, credentials=lb_config.credentials) + kwargs["credentials"] = lb_config.credentials + return function(*args, **kwargs) except InvokeRateLimitError as e: # expire in 60 seconds self.load_balancing_manager.cooldown(lb_config, expire=60) diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index e2d2be92cb..c76cb865c3 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Any, Union, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): ) @classmethod - def _get_user(cls, user_id: str) -> Union[EndUser, Account]: + def _get_user(cls, user_id: str) -> EndUser | Account: """ get the user by user id """ diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 3ecc9867fa..64b45bf28b 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -79,7 +79,7 @@ class RelytVector(BaseVector): if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" - with Session(self.client) as session: + with sessionmaker(bind=self.client).begin() as session: drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """) session.execute(drop_statement) create_statement = sql_text(f""" @@ -104,7 +104,6 @@ class RelytVector(BaseVector): $$); """) session.execute(index_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -208,9 +207,8 @@ class RelytVector(BaseVector): self.delete_by_uuids(ids) def delete(self): - with Session(self.client) as session: + with sessionmaker(bind=self.client).begin() as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";""")) - session.commit() def text_exists(self, id: str) -> bool: with Session(self.client) as session: diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index c948917374..e321681093 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -6,7 +6,7 @@ import sqlalchemy from pydantic import BaseModel, model_validator from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert from sqlalchemy import text as sql_text -from sqlalchemy.orm import Session, declarative_base +from sqlalchemy.orm import Session, declarative_base, sessionmaker from configs import dify_config from core.rag.datasource.vdb.field import Field, parse_metadata_json @@ -97,8 +97,7 @@ class TiDBVector(BaseVector): if redis_client.get(collection_exist_cache_key): return tidb_dist_func = self._get_distance_func() - with Session(self._engine) as session: - session.begin() + with sessionmaker(bind=self._engine).begin() as session: create_statement = sql_text(f""" CREATE TABLE IF NOT EXISTS {self._collection_name} ( id CHAR(36) PRIMARY KEY, @@ -115,7 +114,6 @@ class TiDBVector(BaseVector): ); """) session.execute(create_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -238,9 +236,8 @@ class TiDBVector(BaseVector): return [] def delete(self): - with Session(self._engine) as session: + with sessionmaker(bind=self._engine).begin() as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) - session.commit() def _get_distance_func(self) -> str: match self._distance_func: diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 4a731bf277..a487c49053 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -3,8 +3,7 @@ import logging import re import uuid -from collections.abc import Mapping -from typing import Any, cast +from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) @@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService _file_access_controller = DatabaseFileAccessController() +class ParagraphFormatPreviewDict(TypedDict): + chunk_structure: str + preview: list[dict[str, Any]] + total_segments: int + + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword = Keyword(dataset) keyword.add_texts(documents) - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return { + result: ParagraphFormatPreviewDict = { "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks), } + return result else: raise ValueError("Chunks is not a list") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 53596b5de8..2db233874a 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -3,8 +3,7 @@ import json import logging import uuid -from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict from sqlalchemy import delete, select @@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) +class ParentChildFormatPreviewDict(TypedDict): + chunk_structure: str + parent_mode: str + preview: list[dict[str, Any]] + total_segments: int + + class ParentChildIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -351,17 +357,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict: parent_childs = ParentChildStructureChunk.model_validate(chunks) preview = [] for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) - return { + result: ParentChildFormatPreviewDict = { "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), } + return result def generate_summary_preview( self, diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 273ea0f852..b0f7928092 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,8 +4,7 @@ import logging import re import threading import uuid -from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict import pandas as pd from flask import Flask, current_app @@ -36,6 +35,12 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) +class QAFormatPreviewDict(TypedDict): + chunk_structure: str + qa_preview: list[dict[str, Any]] + total_segments: int + + class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -230,16 +235,17 @@ class QAIndexProcessor(BaseIndexProcessor): else: raise ValueError("Indexing technique must be high quality.") - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> QAFormatPreviewDict: qa_chunks = QAStructureChunk.model_validate(chunks) preview = [] for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) - return { + result: QAFormatPreviewDict = { "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } + return result def generate_summary_preview( self, diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 8283be19f9..a8d37845a5 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,7 +1,7 @@ import base64 from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType @@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner): :param query_type: query type :return: rerank result """ - docs = [] + docs: list[MultimodalRerankInput] = [] doc_ids = set() unique_documents = [] for document in documents: @@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner): if upload_file: blob = storage.load_once(upload_file.key) document_file_base64 = base64.b64encode(blob).decode() - document_file_dict = { - "content": document_file_base64, - "content_type": document.metadata["doc_type"], - } - docs.append(document_file_dict) + docs.append( + MultimodalRerankInput( + content=document_file_base64, + content_type=document.metadata["doc_type"], + ) + ) else: - document_text_dict = { - "content": document.page_content, - "content_type": document.metadata.get("doc_type") or DocType.TEXT, - } - docs.append(document_text_dict) + docs.append( + MultimodalRerankInput( + content=document.page_content, + content_type=document.metadata.get("doc_type") or DocType.TEXT, + ) + ) doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) elif document.provider == "external": if document not in unique_documents: docs.append( - { - "content": document.page_content, - "content_type": document.metadata.get("doc_type") or DocType.TEXT, - } + MultimodalRerankInput( + content=document.page_content, + content_type=document.metadata.get("doc_type") or DocType.TEXT, + ) ) unique_documents.append(document) @@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner): if upload_file: blob = storage.load_once(upload_file.key) file_query = base64.b64encode(blob).decode() - file_query_dict = { - "content": file_query, - "content_type": DocType.IMAGE, - } + file_query_input = MultimodalRerankInput( + content=file_query, + content_type=DocType.IMAGE, + ) rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n + query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index a59d167a0a..d8674b3af9 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -6,7 +6,6 @@ import os import time from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Union from uuid import uuid4 import httpx @@ -158,7 +157,7 @@ class ToolFileManager: return tool_file - def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary(self, id: str) -> tuple[bytes, str] | None: """ get file binary @@ -176,7 +175,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None: """ get file binary diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2593e381cf..b4252e1a3e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import sqlalchemy as sa from graphon.runtime import VariablePool @@ -100,7 +100,7 @@ class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} _builtin_providers_loaded = False - _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + _builtin_tools_labels: dict[str, I18nObject | None] = {} @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: @@ -190,7 +190,7 @@ class ToolManager: invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, - ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: + ) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool: """ get the tool runtime @@ -398,7 +398,7 @@ class ToolManager: agent_tool: AgentToolEntity, user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: "VariablePool | None" = None, ) -> Tool: """ get the agent tool runtime @@ -442,7 +442,7 @@ class ToolManager: workflow_tool: WorkflowToolRuntimeSpec, user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: "VariablePool | None" = None, ) -> Tool: """ get the workflow tool runtime @@ -634,7 +634,7 @@ class ToolManager: cls._builtin_providers_loaded = False @classmethod - def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: + def get_tool_label(cls, tool_name: str) -> I18nObject | None: """ get the tool label @@ -993,7 +993,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str: + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str: try: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) @@ -1001,7 +1001,7 @@ class ToolManager: mcp_provider = mcp_service.get_provider_entity( provider_id=provider_id, tenant_id=tenant_id, by_server_id=True ) - return mcp_provider.provider_icon + return cast(EmojiIconDict | str, mcp_provider.provider_icon) except ValueError: raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") except Exception: @@ -1013,7 +1013,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> str | EmojiIconDict | dict[str, str]: + ) -> str | EmojiIconDict: """ get the tool icon @@ -1052,7 +1052,7 @@ class ToolManager: def _convert_tool_parameters_type( cls, parameters: list[ToolParameter], - variable_pool: Optional["VariablePool"], + variable_pool: "VariablePool | None", tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bb5b3ba76e..2264981abd 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -118,7 +118,8 @@ class ToolFileMessageTransformer: if not isinstance(message.message, ToolInvokeMessage.BlobMessage): raise ValueError("unexpected message type") - assert isinstance(message.message.blob, bytes) + if not isinstance(message.message.blob, bytes): + raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}") tool_file_manager = ToolFileManager() tool_file = tool_file_manager.create_file_by_raw( user_id=user_id, diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index b9e592cadb..a619b9342d 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.retry import Retry from redis.sentinel import Sentinel +from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp @@ -126,6 +127,35 @@ redis_client: RedisClientWrapper = RedisClientWrapper() _pubsub_redis_client: redis.Redis | RedisCluster | None = None +class RedisSSLParamsDict(TypedDict): + ssl_cert_reqs: int + ssl_ca_certs: str | None + ssl_certfile: str | None + ssl_keyfile: str | None + + +class RedisHealthParamsDict(TypedDict): + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + health_check_interval: int | None + + +class RedisBaseParamsDict(TypedDict): + username: str | None + password: str | None + db: int + encoding: str + encoding_errors: str + decode_responses: bool + protocol: int + cache_config: CacheConfig | None + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + health_check_interval: int | None + + def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: """Get SSL configuration for Redis connection.""" if not dify_config.REDIS_USE_SSL: @@ -171,14 +201,14 @@ def _get_retry_policy() -> Retry: ) -def _get_connection_health_params() -> dict[str, Any]: +def _get_connection_health_params() -> RedisHealthParamsDict: """Get connection health and retry parameters for standalone and Sentinel Redis clients.""" - return { - "retry": _get_retry_policy(), - "socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT, - "socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, - "health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL, - } + return RedisHealthParamsDict( + retry=_get_retry_policy(), + socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT, + socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, + health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL, + ) def _get_cluster_connection_health_params() -> dict[str, Any]: @@ -189,26 +219,26 @@ def _get_cluster_connection_health_params() -> dict[str, Any]: here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` are passed through. """ - params = _get_connection_health_params() + params: dict[str, Any] = dict(_get_connection_health_params()) return {k: v for k, v in params.items() if k != "health_check_interval"} -def _get_base_redis_params() -> dict[str, Any]: +def _get_base_redis_params() -> RedisBaseParamsDict: """Get base Redis connection parameters including retry and health policy.""" - return { - "username": dify_config.REDIS_USERNAME, - "password": dify_config.REDIS_PASSWORD or None, - "db": dify_config.REDIS_DB, - "encoding": "utf-8", - "encoding_errors": "strict", - "decode_responses": False, - "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, - "cache_config": _get_cache_configuration(), + return RedisBaseParamsDict( + username=dify_config.REDIS_USERNAME, + password=dify_config.REDIS_PASSWORD or None, + db=dify_config.REDIS_DB, + encoding="utf-8", + encoding_errors="strict", + decode_responses=False, + protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, + cache_config=_get_cache_configuration(), **_get_connection_health_params(), - } + ) -def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: +def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create Redis client using Sentinel configuration.""" if not dify_config.REDIS_SENTINELS: raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") @@ -232,7 +262,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, sentinel_kwargs=sentinel_kwargs, ) - master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + params: dict[str, Any] = {**redis_params} + master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params) return master @@ -259,18 +290,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: return cluster -def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: +def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create standalone Redis client.""" connection_class, ssl_kwargs = _get_ssl_configuration() - params = {**redis_params} - params.update( - { - "host": dify_config.REDIS_HOST, - "port": dify_config.REDIS_PORT, - "connection_class": connection_class, - } - ) + params: dict[str, Any] = { + **redis_params, + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } if dify_config.REDIS_MAX_CONNECTIONS: params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS @@ -293,8 +322,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | kwargs["max_connections"] = max_conns return RedisCluster.from_url(pubsub_url, **kwargs) - health_params = _get_connection_health_params() - kwargs = {**health_params} + standalone_health_params: dict[str, Any] = dict(_get_connection_health_params()) + kwargs = {**standalone_health_params} if max_conns: kwargs["max_connections"] = max_conns return redis.Redis.from_url(pubsub_url, **kwargs) diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 1dd92caeae..ad83826427 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab handler = _get_handler_instance(handler_class or SpanHandler) tracer = get_tracer(__name__) - return handler.wrapper( - tracer=tracer, - wrapped=func, - args=args, - kwargs=kwargs, - ) + return handler.wrapper(tracer, func, *args, **kwargs) return cast(Callable[P, R], wrapper) diff --git a/api/extensions/otel/decorators/handler.py b/api/extensions/otel/decorators/handler.py index e465a615a6..b0d9fa7af6 100644 --- a/api/extensions/otel/decorators/handler.py +++ b/api/extensions/otel/decorators/handler.py @@ -1,8 +1,8 @@ import inspect -from collections.abc import Callable, Mapping +from collections.abc import Callable from typing import Any -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer class SpanHandler: @@ -16,9 +16,9 @@ class SpanHandler: exceptions. Handlers can override the wrapper method to customize behavior. """ - _signature_cache: dict[Callable[..., Any], inspect.Signature] = {} + _signature_cache: dict[Callable[..., object], inspect.Signature] = {} - def _build_span_name(self, wrapped: Callable[..., Any]) -> str: + def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str: """ Build the span name from the wrapped function. @@ -29,11 +29,11 @@ class SpanHandler: """ return f"{wrapped.__module__}.{wrapped.__qualname__}" - def _extract_arguments[T]( + def _extract_arguments[**P, R]( self, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, ) -> dict[str, Any] | None: """ Extract function arguments using inspect.signature. @@ -59,13 +59,13 @@ class SpanHandler: except Exception: return None - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: """ Fully control the wrapper behavior. diff --git a/api/extensions/otel/decorators/handlers/generate_handler.py b/api/extensions/otel/decorators/handlers/generate_handler.py index cc6c75304f..df5142c310 100644 --- a/api/extensions/otel/decorators/handlers/generate_handler.py +++ b/api/extensions/otel/decorators/handlers/generate_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -15,15 +14,15 @@ logger = logging.getLogger(__name__) class AppGenerateHandler(SpanHandler): """Span handler for ``AppGenerateService.generate``.""" - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py index 8abd60197c..6b2112ceb2 100644 --- a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py +++ b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -14,15 +13,15 @@ logger = logging.getLogger(__name__) class WorkflowAppRunnerHandler(SpanHandler): """Span handler for ``WorkflowAppRunner.run``.""" - def wrapper( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - ) -> Any: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/libs/db_migration_lock.py b/api/libs/db_migration_lock.py index 1d3a81e0a2..ca8956e397 100644 --- a/api/libs/db_migration_lock.py +++ b/api/libs/db_migration_lock.py @@ -14,9 +14,15 @@ from __future__ import annotations import logging import threading -from typing import Any +from typing import TYPE_CHECKING, Any +import redis +from redis.cluster import RedisCluster from redis.exceptions import LockNotOwnedError, RedisError +from redis.lock import Lock + +if TYPE_CHECKING: + from extensions.ext_redis import RedisClientWrapper logger = logging.getLogger(__name__) @@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock: primary error/exit code. """ - _redis_client: Any + _redis_client: redis.Redis | RedisCluster | RedisClientWrapper _name: str _ttl_seconds: float _renew_interval_seconds: float _log_context: str | None _logger: logging.Logger - _lock: Any + _lock: Lock | None _stop_event: threading.Event | None _thread: threading.Thread | None _acquired: bool def __init__( self, - redis_client: Any, + redis_client: redis.Redis | RedisCluster | RedisClientWrapper, name: str, ttl_seconds: float = 60, renew_interval_seconds: float | None = None, @@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock: ) self._thread.start() - def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None: + def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None: while not stop_event.wait(self._renew_interval_seconds): try: lock.reacquire() diff --git a/api/libs/helper.py b/api/libs/helper.py index ece53e8806..e7decd43b3 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Callable, Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast +from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast from uuid import UUID from zoneinfo import available_timezones @@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str: return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") -def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: +def extract_tenant_id(user: "Account | EndUser") -> str | None: """ Extract tenant_id from Account or EndUser object. @@ -164,7 +164,10 @@ def email(email): EmailStr = Annotated[str, AfterValidator(email)] -def uuid_value(value: Any) -> str: +def uuid_value(value: str | UUID) -> str: + if isinstance(value, UUID): + return str(value) + if value == "": return str(value) @@ -405,7 +408,7 @@ class TokenManager: def generate_token( cls, token_type: str, - account: Optional["Account"] = None, + account: "Account | None" = None, email: str | None = None, additional_data: dict | None = None, ) -> str: @@ -465,9 +468,7 @@ class TokenManager: return current_token @classmethod - def _set_current_token_for_account( - cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float] - ): + def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float): key = cls._get_account_token_key(account_id, token_type) expiry_seconds = int(expiry_minutes * 60) redis_client.setex(key, expiry_seconds, token) diff --git a/api/libs/pyrefly_type_coverage.py b/api/libs/pyrefly_type_coverage.py new file mode 100644 index 0000000000..369b8dff3c --- /dev/null +++ b/api/libs/pyrefly_type_coverage.py @@ -0,0 +1,145 @@ +"""Helpers for generating type-coverage summaries from pyrefly report output.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import TypedDict + + +class CoverageSummary(TypedDict): + n_modules: int + n_typable: int + n_typed: int + n_any: int + n_untyped: int + coverage: float + strict_coverage: float + + +_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__) + +_EMPTY_SUMMARY: CoverageSummary = { + "n_modules": 0, + "n_typable": 0, + "n_typed": 0, + "n_any": 0, + "n_untyped": 0, + "coverage": 0.0, + "strict_coverage": 0.0, +} + + +def parse_summary(report_json: str) -> CoverageSummary: + """Extract the summary section from ``pyrefly report`` JSON output. + + Returns an empty summary when *report_json* is empty or malformed so that + the CI workflow can degrade gracefully instead of crashing. + """ + if not report_json or not report_json.strip(): + return _EMPTY_SUMMARY.copy() + + try: + data = json.loads(report_json) + except json.JSONDecodeError: + return _EMPTY_SUMMARY.copy() + + summary = data.get("summary") + if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary): + return _EMPTY_SUMMARY.copy() + + return { + "n_modules": summary["n_modules"], + "n_typable": summary["n_typable"], + "n_typed": summary["n_typed"], + "n_any": summary["n_any"], + "n_untyped": summary["n_untyped"], + "coverage": summary["coverage"], + "strict_coverage": summary["strict_coverage"], + } + + +def format_summary_markdown(summary: CoverageSummary) -> str: + """Format a single coverage summary as a Markdown table.""" + + return ( + "| Metric | Value |\n" + "| --- | ---: |\n" + f"| Modules | {summary['n_modules']} |\n" + f"| Typable symbols | {summary['n_typable']:,} |\n" + f"| Typed symbols | {summary['n_typed']:,} |\n" + f"| Untyped symbols | {summary['n_untyped']:,} |\n" + f"| Any symbols | {summary['n_any']:,} |\n" + f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n" + f"| Strict coverage | {summary['strict_coverage']:.2f}% |" + ) + + +def format_comparison_markdown( + base: CoverageSummary, + pr: CoverageSummary, +) -> str: + """Format a comparison between base and PR coverage as Markdown.""" + + coverage_delta = pr["coverage"] - base["coverage"] + strict_delta = pr["strict_coverage"] - base["strict_coverage"] + typed_delta = pr["n_typed"] - base["n_typed"] + untyped_delta = pr["n_untyped"] - base["n_untyped"] + + def _fmt_delta(value: float, fmt: str = ".2f") -> str: + sign = "+" if value > 0 else "" + return f"{sign}{value:{fmt}}" + + lines = [ + "| Metric | Base | PR | Delta |", + "| --- | ---: | ---: | ---: |", + (f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"), + ( + f"| Strict coverage | {base['strict_coverage']:.2f}% " + f"| {pr['strict_coverage']:.2f}% " + f"| {_fmt_delta(strict_delta)}% |" + ), + (f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"), + (f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"), + ( + f"| Modules | {base['n_modules']} " + f"| {pr['n_modules']} " + f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |" + ), + ] + return "\n".join(lines) + + +def main() -> int: + """Read pyrefly report JSON from stdin and print a Markdown summary. + + Accepts an optional ``--base `` argument. When provided, the output + includes a base-vs-PR comparison table. + """ + + args = sys.argv[1:] + + base_file: str | None = None + if "--base" in args: + idx = args.index("--base") + if idx + 1 >= len(args): + sys.stderr.write("error: --base requires a file path\n") + return 1 + base_file = args[idx + 1] + + pr_report = sys.stdin.read() + pr_summary = parse_summary(pr_report) + + if base_file is not None: + base_text = Path(base_file).read_text() if Path(base_file).exists() else "" + base_summary = parse_summary(base_text) + sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n") + else: + sys.stdout.write(format_summary_markdown(pr_summary) + "\n") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/api/models/base.py b/api/models/base.py index b7023b9c8b..5acdf184f4 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase): class DefaultFieldsMixin: + """Mixin for models that inherit from Base (non-dataclass).""" + id: Mapped[str] = mapped_column( StringUUID, primary_key=True, @@ -53,6 +55,42 @@ class DefaultFieldsMixin: return f"<{self.__class__.__name__}(id={self.id})>" +class DefaultFieldsDCMixin(MappedAsDataclass): + """Mixin for models that inherit from TypeBase (MappedAsDataclass).""" + + __abstract__ = True + + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + insert_default=lambda: str(uuidv7()), + default_factory=lambda: str(uuidv7()), + init=False, + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + insert_default=naive_utc_now, + default_factory=naive_utc_now, + init=False, + server_default=func.current_timestamp(), + ) + + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + insert_default=naive_utc_now, + default_factory=naive_utc_now, + init=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(id={self.id})>" + + def gen_uuidv4_string() -> str: """gen_uuidv4_string generate a UUIDv4 string. diff --git a/api/models/model.py b/api/models/model.py index d2ff8065e2..0ea2259a19 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -913,11 +913,7 @@ class TrialApp(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( - sa.DateTime, - nullable=False, - insert_default=func.current_timestamp(), - server_default=func.current_timestamp(), - init=False, + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3) @@ -941,11 +937,7 @@ class AccountTrialAppRecord(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( - sa.DateTime, - nullable=False, - insert_default=func.current_timestamp(), - server_default=func.current_timestamp(), - init=False, + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @property diff --git a/api/models/workflow.py b/api/models/workflow.py index 347804b091..aca9b7ac70 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -4,7 +4,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -121,7 +121,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": + def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType": """ Get workflow type from app mode. @@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None": return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property diff --git a/api/services/account_service.py b/api/services/account_service.py index 4b58b3b697..1f5f81e5bd 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast from pydantic import BaseModel, TypeAdapter from sqlalchemy import delete, func, select, update -from sqlalchemy.orm import Session, sessionmaker + +from core.db.session_factory import session_factory class InvitationData(TypedDict): @@ -800,19 +801,19 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + def get_account_by_email_with_case_fallback(email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. """ - query_session = session or db.session - account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account + with session_factory.create_session() as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -1516,8 +1517,7 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 17ed98d301..5e8c7aa337 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -4,7 +4,7 @@ import logging import threading import uuid from collections.abc import Callable, Generator, Mapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -88,7 +88,7 @@ class AppGenerateService: def generate( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, @@ -356,11 +356,11 @@ class AppGenerateService: def generate_more_like_this( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[Mapping, Generator]: + ) -> Mapping | Generator: """ Generate more like this :param app_model: app model diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 55ae1e03b1..a731d5c048 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac import json from datetime import UTC, datetime -from typing import Any, Union +from typing import Any from celery.result import AsyncResult from sqlalchemy import select @@ -50,7 +50,7 @@ class AsyncWorkflowService: @classmethod def trigger_workflow_async( - cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData + cls, session: Session, user: Account | EndUser, trigger_data: TriggerData ) -> AsyncTriggerResponse: """ Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK @@ -177,7 +177,7 @@ class AsyncWorkflowService: @classmethod def reinvoke_trigger( - cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str + cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str ) -> AsyncTriggerResponse: """ Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3e952059ac..9c71902849 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2822,6 +2822,10 @@ class DocumentService: knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) + if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL: + if not knowledge_config.process_rule.rules.parent_mode: + knowledge_config.process_rule.rules.parent_mode = "paragraph" + if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index d30ec940f5..ce5dee4943 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Any, Union, cast +from typing import Any, cast from urllib.parse import urlparse import httpx @@ -148,18 +148,23 @@ class ExternalDatasetService: db.session.commit() @staticmethod - def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: + def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]: + """ + Return usage for an external knowledge API within a single tenant. + + The caller already scopes access by tenant, so this query must do the + same; otherwise the endpoint becomes a cross-tenant UUID oracle. + """ count = ( db.session.scalar( select(func.count(ExternalKnowledgeBindings.id)).where( - ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id + ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id, + ExternalKnowledgeBindings.tenant_id == tenant_id, ) ) or 0 ) - if count > 0: - return True, count - return False, 0 + return count > 0, count @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: @@ -190,9 +195,7 @@ class ExternalDatasetService: raise ValueError(f"{parameter.get('name')} is required") @staticmethod - def process_external_api( - settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]] - ) -> httpx.Response: + def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response: """ do http request depending on api bundle """ diff --git a/api/services/file_service.py b/api/services/file_service.py index 7443ca3271..79a935de4b 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Iterator, Sequence from contextlib import contextmanager, suppress from tempfile import NamedTemporaryFile -from typing import Literal, Union +from typing import Literal from zipfile import ZIP_DEFLATED, ZipFile from graphon.file import helpers as file_helpers @@ -52,7 +52,7 @@ class FileService: filename: str, content: bytes, mimetype: str, - user: Union[Account, EndUser], + user: Account | EndUser, source: Literal["datasets"] | None = None, source_url: str = "", ) -> UploadFile: diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 3cce83a975..41b6b885b2 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, TypedDict, Union +from typing import Any, TypedDict from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -626,7 +626,7 @@ class ModelLoadBalancingService: def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + ) -> ModelCredentialSchema | ProviderCredentialSchema: """Get form schemas.""" if provider_configuration.provider.model_credential_schema: return provider_configuration.provider.model_credential_schema diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 0d2a70acbd..3cca4268d0 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -1,14 +1,13 @@ from sqlalchemy import select -from sqlalchemy.orm import sessionmaker -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginPermission class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: return session.scalar( select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) @@ -19,7 +18,7 @@ class PluginPermissionService: install_permission: TenantPluginPermission.InstallPermission, debug_permission: TenantPluginPermission.DebugPermission, ): - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session, session.begin(): permission = session.scalar( select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 10e89b1dba..56bc785958 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Union +from typing import Any from configs import dify_config from core.app.apps.pipeline.pipeline_generator import PipelineGenerator @@ -17,7 +17,7 @@ class PipelineGenerateService: def generate( cls, pipeline: Pipeline, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f6d80f9a6e..5fc5b412b3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,7 +5,7 @@ import threading import time from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Union, cast +from typing import Any, cast from uuid import uuid4 from flask_login import current_user @@ -1387,7 +1387,7 @@ class RagPipelineService: "uninstalled_recommended_plugins": uninstalled_plugin_list, } - def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]): + def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser): """ Retry error document """ diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 4fd2ea1628..72954a3102 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Union +from typing import Any from pydantic import TypeAdapter, ValidationError from yarl import URL @@ -69,7 +69,7 @@ class ToolTransformService: return "" @staticmethod - def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): + def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity): """ repack provider diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 0a73c91279..45e1f80e35 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -7,15 +7,16 @@ with appropriate retry policies and error handling. import logging from datetime import UTC, datetime -from typing import Any +from typing import Any, NotRequired from celery import shared_task from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from typing_extensions import TypedDict from configs import dify_config -from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.app.layers.timeslice_layer import TimeSliceLayer @@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf logger = logging.getLogger(__name__) +class WorkflowGeneratorArgsDict(TypedDict): + inputs: dict[str, Any] + files: list[Any] + _skip_prepare_user_inputs: bool + workflow_id: NotRequired[str] + + @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): """Execute workflow for professional tier with highest priority""" @@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]): ) -def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]: +def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict: """Build args passed into WorkflowAppGenerator.generate for Celery executions.""" - - args: dict[str, Any] = { + return { "inputs": dict(trigger_data.inputs), "files": list(trigger_data.files), - SKIP_PREPARE_USER_INPUTS_KEY: True, + "_skip_prepare_user_inputs": True, } - return args def _execute_workflow_common( diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 879c337319..320da85b60 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 7b7393dade..d2703ed5cc 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") def test_reset_fetches_account_with_original_email( self, mock_get_reset_data, mock_revoke_token, + mock_db, mock_get_account, mock_update_account, app, @@ -126,6 +128,7 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account wraps_features = SimpleNamespace(enable_email_password_login=True) with ( @@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a2f1328579..1eabb45422 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -437,7 +437,10 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8f9db287e3..50249bcd74 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -335,10 +335,12 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, + mock_db, mock_get_account, mock_revoke_token, mock_get_data, @@ -356,6 +358,7 @@ class TestForgotPasswordResetApi: # Arrange mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 04ad143103..f14b2c0ae5 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi: @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") - @patch("controllers.web.forgot_password.sessionmaker") def test_should_normalize_email_before_sending( self, - mock_session_cls, mock_extract_ip, mock_rate_limit, mock_get_account, @@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi: mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_mail.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password", - method="POST", - json={"email": "User@Example.com", "language": "zh-Hans"}, - ): - response = ForgotPasswordSendEmailApi().post() + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -153,14 +148,14 @@ class TestForgotPasswordResetApi: @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") def test_should_fetch_account_with_fallback( self, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_get_account, mock_update_account, app, @@ -168,29 +163,27 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = mock_account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "token-123", - "new_password": "ValidPass123!", - "password_confirm": "ValidPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -199,7 +192,7 @@ class TestForgotPasswordResetApi: mock_get_account, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_token_bytes, mock_hash_password, app, @@ -207,20 +200,18 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() mock_get_account.return_value = account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("reset-token") diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py similarity index 51% rename from api/tests/unit_tests/services/test_hit_testing_service.py rename to api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index 80e9729f5b..f332ba05ec 100644 --- a/api/tests/unit_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -1,239 +1,193 @@ +from __future__ import annotations + import json from typing import Any, cast from unittest.mock import ANY, MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.models.document import Document -from models.dataset import Dataset +from models.dataset import Dataset, DatasetQuery from services.hit_testing_service import HitTestingService -class TestHitTestingService: - """Test suite for HitTestingService""" +def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset: + tenant_id = str(uuid4()) + created_by = str(uuid4()) + ds = Dataset( + tenant_id=kwargs.get("tenant_id", tenant_id), + name=kwargs.get("name", "test-dataset"), + created_by=kwargs.get("created_by", created_by), + provider=provider, + ) + db_session.add(ds) + db_session.commit() + db_session.refresh(ds) + return ds - # ===== Utility Method Tests ===== + +class TestHitTestingService: + # ── Utility methods (pure logic, no DB) ──────────────────────────── def test_escape_query_for_search_should_escape_double_quotes(self): - """Test that escape_query_for_search escapes double quotes correctly""" - # Arrange query = 'test "query" with quotes' - expected = 'test \\"query\\" with quotes' - - # Act result = HitTestingService.escape_query_for_search(query) - - # Assert - assert result == expected + assert result == 'test \\"query\\" with quotes' def test_hit_testing_args_check_should_pass_with_valid_query(self): - """Test that hit_testing_args_check passes with a valid query""" - # Arrange - args = {"query": "valid query"} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"query": "valid query"}) def test_hit_testing_args_check_should_pass_with_valid_attachments(self): - """Test that hit_testing_args_check passes with valid attachment_ids""" - # Arrange - args = {"attachment_ids": ["id1", "id2"]} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]}) def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): - """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" - # Arrange - args = {} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query or attachment_ids is required" in str(exc_info.value) + with pytest.raises(ValueError, match="Query or attachment_ids is required"): + HitTestingService.hit_testing_args_check({}) def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): - """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" - # Arrange - args = {"query": "a" * 251} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query cannot exceed 250 characters" in str(exc_info.value) + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): + HitTestingService.hit_testing_args_check({"query": "a" * 251}) def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): - """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" - # Arrange - args = {"attachment_ids": "not a list"} + with pytest.raises(ValueError, match="Attachment_ids must be a list"): + HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Attachment_ids must be a list" in str(exc_info.value) - - # ===== Response Formatting Tests ===== + # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") def test_compact_retrieve_response_should_format_correctly(self, mock_format): - """Test that compact_retrieve_response formats the response correctly""" - # Arrange query = "test query" mock_doc = MagicMock(spec=Document) - documents = [mock_doc] mock_record = MagicMock() mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - # Act - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) - # Assert assert cast(dict[str, Any], result["query"])["content"] == query assert len(result["records"]) == 1 assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" - mock_format.assert_called_once_with(documents) + mock_format.assert_called_once_with([mock_doc]) - def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): - """Test that compact_external_retrieve_response returns records when dataset provider is external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "external" - query = "test query" + def test_compact_external_retrieve_response_should_return_records_for_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert len(result["records"]) == 2 assert cast(dict[str, Any], result["records"][0])["content"] == "c1" assert cast(dict[str, Any], result["records"][1])["title"] == "t2" - def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): - """Test that compact_external_retrieve_response returns empty records for non-external provider""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" - documents = [{"content": "c1"}] + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="vendor") - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== External Retrieve Tests ===== + # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): - """Test that external_retrieve successfully retrieves from external provider and commits query""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - dataset.provider = "external" - query = 'test "query"' + def test_external_retrieve_should_succeed_for_external_provider( + self, mock_ext_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") + account_id = str(uuid4()) account = MagicMock() - account.id = "account_id" - + account.id = account_id mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.external_retrieve( dataset=dataset, - query=query, + query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" - - # Verify call to RetrievalService.external_retrieve with escaped query mock_ext_retrieve.assert_called_once_with( - dataset_id="dataset_id", + dataset_id=dataset.id, query='test \\"query\\"', external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ) - # Verify DatasetQuery record was added and committed - mock_add.assert_called_once() - mock_commit.assert_called_once() + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self): - """Test that external_retrieve returns empty results immediately if provider is not external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" + def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - # Act - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== Retrieve Tests ===== + # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve uses default model when retrieval_model is not provided""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" + def test_retrieve_should_use_default_model_when_none_provided( + self, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None - query = "test query" account = MagicMock() - account.id = "account_id" - + account.id = str(uuid4()) mock_retrieve.return_value = [] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" mock_retrieve.assert_called_once() - # Verify top_k from default_retrieval_model (4) assert mock_retrieve.call_args.kwargs["top_k"] == 4 - mock_commit.assert_called_once() + + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): - """Test that retrieve correctly calls metadata filtering when conditions are present""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_metadata_filtering( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "semantic_search", @@ -242,29 +196,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response - mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_get_meta.assert_called_once() mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): - """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_return_empty_if_metadata_filtering_fails( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() retrieval_model = { @@ -274,37 +226,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response: condition returned but no IDs mock_get_meta.return_value = ({}, "condition_string") - # Act result = cast( dict[str, Any], HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, ), ) - # Assert assert result["records"] == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) attachment_ids = ["att1", "att2"] retrieval_model = { @@ -315,21 +257,19 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, attachment_ids=attachment_ids, ) - # Assert mock_retrieve.assert_called_once_with( retrieval_method=ANY, - dataset_id="dataset_id", - query=query, + dataset_id=dataset.id, + query="test query", attachment_ids=attachment_ids, top_k=4, score_threshold=0.0, @@ -338,26 +278,27 @@ class TestHitTestingService: weights=None, document_ids_filter=None, ) - # Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images) - # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) - called_query = mock_add.call_args[0][0] - query_content = json.loads(called_query.content) + + # Verify DatasetQuery was persisted with correct content structure + db_session_with_containers.expire_all() + latest = db_session_with_containers.scalar( + select(DatasetQuery) + .where(DatasetQuery.dataset_id == dataset.id) + .order_by(DatasetQuery.created_at.desc()) + .limit(1) + ) + assert latest is not None + query_content = json.loads(latest.content) assert len(query_content) == 3 # 1 text + 2 images assert query_content[0]["content_type"] == "text_query" assert query_content[1]["content_type"] == "image_query" assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve passes reranking and threshold parameters correctly""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "hybrid_search", @@ -371,12 +312,14 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_retrieve.assert_called_once() kwargs = mock_retrieve.call_args.kwargs assert kwargs["score_threshold"] == 0.5 diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py new file mode 100644 index 0000000000..e2e1a228b2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import TraceAppConfig +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.ops_service import OpsService +from tests.test_containers_integration_tests.helpers import generate_valid_password + + +class TestOpsService: + @pytest.fixture + def mock_external_service_dependencies(self): + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_ops_trace_manager(self): + with patch("services.ops_service.OpsTraceManager") as mock: + yield mock + + def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies): + fake = Faker() + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + }, + account, + ) + return app, account + + _SENTINEL = object() + + def _insert_trace_config( + self, + db_session: Session, + app_id: str, + provider: str, + tracing_config: dict | None | object = _SENTINEL, + ) -> TraceAppConfig: + trace_config = TraceAppConfig( + app_id=app_id, + tracing_provider=provider, + tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"}, + ) + db_session.add(trace_config) + db_session.commit() + return trace_config + + # ── get_tracing_app_config ───────────────────────────────────────── + + def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, "arize") + result = OpsService.get_tracing_app_config(fake_app_id, "arize") + assert result is None + + def test_get_tracing_app_config_none_config( + self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None) + + with pytest.raises(ValueError, match="Tracing config cannot be None."): + OpsService.get_tracing_app_config(app.id, "arize") + + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {} + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == default_url + + @pytest.mark.parametrize( + "provider", + ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"], + ) + def test_get_tracing_app_config_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"} + mock_otm.get_trace_config_project_url.return_value = "success_url" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == "success_url" + + def test_get_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.return_value = "key" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + def test_get_tracing_app_config_langfuse_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + # ── create_tracing_app_config ────────────────────────────────────── + + def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + def test_create_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_ops_trace_manager + ): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + result = OpsService.create_tracing_app_config( + str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"} + ) + assert result == {"error": "Invalid Credentials"} + + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config + ): + # Existing config causes the service to return None before reaching the DB insert + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(provider)) + + result = OpsService.create_tracing_app_config(app.id, provider, config) + + assert result is None + + def test_create_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_key.return_value = "key" + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config( + app.id, + TracingProviderEnum.LANGFUSE, + {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}, + ) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is None + + def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_create_tracing_app_config_with_empty_other_keys( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + # "project" is in other_keys for Arize; providing "" triggers default substitution + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("no url") + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""}) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.return_value = "http://project_url" + mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result == {"result": "success"} + + # ── update_tracing_app_config ────────────────────────────────────── + + def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + + def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE)) + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = False + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + def test_update_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {"updated": "config"} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is not None + assert result["app_id"] == app.id + + # ── delete_tracing_app_config ────────────────────────────────────── + + def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session): + result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_delete_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize") + + result = OpsService.delete_tracing_app_config(app.id, "arize") + + assert result is True + remaining = db_session_with_containers.scalar( + select(TraceAppConfig) + .where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize") + .limit(1) + ) + assert remaining is None diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 4fe65d5803..7825f502f7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -233,11 +233,10 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None - assert result.password is not None - assert result.password_salt is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None + assert refreshed.password is not None + assert refreshed.password_salt is not None def test_authenticate_account_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -414,9 +413,8 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None def test_get_user_through_email_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 161d0c41e8..514bbbe040 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -1,3 +1,4 @@ +from importlib import import_module from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -11,6 +12,7 @@ from controllers.console.datasets.external import ( BedrockRetrievalApi, ExternalApiTemplateApi, ExternalApiTemplateListApi, + ExternalApiUseCheckApi, ExternalDatasetCreateApi, ExternalKnowledgeHitTestingApi, ) @@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService +external_controller = import_module("controllers.console.datasets.external") + def unwrap(func): while hasattr(func, "__wrapped__"): @@ -44,10 +48,11 @@ def current_user(): @pytest.fixture(autouse=True) -def mock_auth(mocker, current_user): - mocker.patch( - "controllers.console.datasets.external.current_account_with_tenant", - return_value=(current_user, "tenant-1"), +def mock_auth(monkeypatch, current_user): + monkeypatch.setattr( + external_controller, + "current_account_with_tenant", + lambda: (current_user, "tenant-1"), ) @@ -136,6 +141,26 @@ class TestExternalApiTemplateApi: method(api, "api-id") +class TestExternalApiUseCheckApi: + def test_get_scopes_usage_check_to_current_tenant(self, app): + api = ExternalApiUseCheckApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + ExternalDatasetService, + "external_knowledge_api_use_check", + return_value=(True, 2), + ) as mock_use_check, + ): + response, status = method(api, "api-id") + + assert status == 200 + assert response == {"is_using": True, "count": 2} + mock_use_check.assert_called_once_with("api-id", "tenant-1") + + class TestExternalDatasetCreateApi: def test_create_success(self, app): api = ExternalDatasetCreateApi() diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 7f9fe9cbf9..dd643faac9 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -233,15 +233,20 @@ class TestCheckEmailUnique: def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): - session = MagicMock() + mock_session = MagicMock() first = MagicMock() first.scalar_one_or_none.return_value = None second = MagicMock() expected_account = MagicMock() second.scalar_one_or_none.return_value = expected_account - session.execute.side_effect = [first, second] + mock_session.execute.side_effect = [first, second] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + mock_factory = MagicMock() + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + + with patch("services.account_service.session_factory", mock_factory): + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account - assert session.execute.call_count == 2 + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/core/mcp/test_entities.py b/api/tests/unit_tests/core/mcp/test_entities.py index 3fede55916..e99c38285c 100644 --- a/api/tests/unit_tests/core/mcp/test_entities.py +++ b/api/tests/unit_tests/core/mcp/test_entities.py @@ -4,9 +4,7 @@ from unittest.mock import Mock from core.mcp.entities import ( SUPPORTED_PROTOCOL_VERSIONS, - LifespanContextT, RequestContext, - SessionT, ) from core.mcp.session.base_session import BaseSession from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams @@ -198,42 +196,3 @@ class TestRequestContext: assert "RequestContext" in repr_str assert "test-123" in repr_str assert "MockSession" in repr_str - - -class TestTypeVariables: - """Test type variables defined in the module.""" - - def test_session_type_var(self): - """Test SessionT type variable.""" - - # Create a custom session class - class CustomSession(BaseSession): - pass - - # Use in generic context - def process_session(session: SessionT) -> SessionT: - return session - - mock_session = Mock(spec=CustomSession) - result = process_session(mock_session) - assert result == mock_session - - def test_lifespan_context_type_var(self): - """Test LifespanContextT type variable.""" - - # Use in generic context - def process_lifespan(context: LifespanContextT) -> LifespanContextT: - return context - - # Test with different types - str_context = "string-context" - assert process_lifespan(str_context) == str_context - - dict_context = {"key": "value"} - assert process_lifespan(dict_context) == dict_context - - class CustomContext: - pass - - custom_context = CustomContext() - assert process_lifespan(custom_context) == custom_context diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py index ca8cd5e514..43cdb4948d 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py @@ -39,6 +39,25 @@ class _FakeSession: return None +class _FakeBeginContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return None + + +def _patch_both(monkeypatch, module, session): + """Patch both Session and sessionmaker on the module.""" + monkeypatch.setattr(module, "Session", lambda _client: session) + monkeypatch.setattr( + module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session))) + ) + + @pytest.fixture def relyt_module(monkeypatch): for name, module in _build_fake_relyt_modules().items(): @@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch): monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.create_collection(3) session.execute.assert_not_called() monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.create_collection(3) executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) @@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module): # 8. delete commits session -def test_delete_commits_session(relyt_module, monkeypatch): +def test_delete_drops_table(relyt_module, monkeypatch): vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) vector._collection_name = "collection_1" vector.client = MagicMock() vector.embedding_dimension = 3 session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.delete() - session.commit.assert_called_once() + session.execute.assert_called_once() def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py index 951a920f3b..8e19a59af8 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py @@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke session = MagicMock() - class _SessionCtx: + class _BeginCtx: def __enter__(self): return session def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx())) + monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm) vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) vector._collection_name = "collection_1" @@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke vector._create_collection(3) - session.begin.assert_called_once() sql = str(session.execute.call_args.args[0]) assert "VECTOR(3)" in sql assert "VEC_L2_DISTANCE" in sql - session.commit.assert_called_once() tidb_module.redis_client.set.assert_called_once() @@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch): def test_delete_drops_table(tidb_module, monkeypatch): session = MagicMock() session.execute.return_value = None - session.commit = MagicMock() - class _SessionCtx: + class _BeginCtx: def __enter__(self): return session def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx())) + monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm) vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) vector._collection_name = "collection_1" vector._engine = MagicMock() vector.delete() drop_sql = str(session.execute.call_args.args[0]) assert "DROP TABLE IF EXISTS collection_1" in drop_sql - session.commit.assert_called_once() def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py index f7475f2239..12e91f190f 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py @@ -39,7 +39,7 @@ class TestAppGenerateHandler: "root_node_id": None, } - arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs) + arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs) assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate" assert "app_model" in arguments, "Handler uses app_model but parameter is missing" @@ -70,14 +70,11 @@ class TestAppGenerateHandler: handler.wrapper( tracer, dummy_func, - (), - { - "app_model": mock_app_model, - "user": mock_account_user, - "args": {"workflow_id": test_workflow_id}, - "invoke_from": InvokeFrom.DEBUGGER, - "streaming": False, - }, + app_model=mock_app_model, + user=mock_account_user, + args={"workflow_id": test_workflow_id}, + invoke_from=InvokeFrom.DEBUGGER, + streaming=False, ) spans = memory_span_exporter.get_finished_spans() diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py index 500f80fc3c..842e7f55e2 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py @@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler: def runner_run(self): return "result" - handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {}) + handler.wrapper(tracer, runner_run, mock_workflow_runner) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 diff --git a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py index 44788bab9a..bf861e3ef7 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py @@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments: args = (1, 2, 3) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments: args = () kwargs = {"a": 1, "b": 2, "c": 3} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {"b": 2, "c": 3} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments: instance = MyClass() args = (1, 2) kwargs = {} - result = handler._extract_arguments(instance.method, args, kwargs) + result = handler._extract_arguments(instance.method, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is None @@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments: assert func not in handler._signature_cache - handler._extract_arguments(func, (1, 2), {}) + handler._extract_arguments(func, 1, 2) assert func in handler._signature_cache cached_sig = handler._signature_cache[func] - handler._extract_arguments(func, (3, 4), {}) + handler._extract_arguments(func, 3, 4) assert handler._signature_cache[func] is cached_sig @@ -142,7 +142,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - result = handler.wrapper(tracer, test_func, (), {}) + result = handler.wrapper(tracer, test_func) assert result == "result" spans = memory_span_exporter.get_finished_spans() @@ -159,7 +159,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -174,7 +174,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -190,7 +190,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError, match="test error"): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -208,7 +208,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -225,7 +225,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError, match="test error"): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter): @@ -236,7 +236,7 @@ class TestSpanHandlerWrapper: def test_func(a, b, c=10): return a + b + c - result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3}) + result = handler.wrapper(tracer, test_func, 1, 2, c=3) assert result == 6 @@ -249,7 +249,7 @@ class TestSpanHandlerWrapper: def my_function(x): return x * 2 - result = handler.wrapper(tracer, my_function, (5,), {}) + result = handler.wrapper(tracer, my_function, 5) assert result == 10 spans = memory_span_exporter.get_finished_spans() diff --git a/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py b/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py new file mode 100644 index 0000000000..7087490845 --- /dev/null +++ b/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py @@ -0,0 +1,138 @@ +import json + +from libs.pyrefly_type_coverage import ( + CoverageSummary, + format_comparison_markdown, + format_summary_markdown, + parse_summary, +) + + +def _make_report(summary: dict) -> str: + return json.dumps({"module_reports": [], "summary": summary}) + + +_SAMPLE_SUMMARY: dict = { + "n_modules": 100, + "n_typable": 1000, + "n_typed": 400, + "n_any": 50, + "n_untyped": 550, + "coverage": 45.0, + "strict_coverage": 40.0, + "n_functions": 200, + "n_methods": 300, + "n_function_params": 150, + "n_method_params": 250, + "n_classes": 80, + "n_attrs": 40, + "n_properties": 20, + "n_type_ignores": 10, +} + + +def _make_summary( + *, + n_modules: int = 100, + n_typable: int = 1000, + n_typed: int = 400, + n_any: int = 50, + n_untyped: int = 550, + coverage: float = 45.0, + strict_coverage: float = 40.0, +) -> CoverageSummary: + return { + "n_modules": n_modules, + "n_typable": n_typable, + "n_typed": n_typed, + "n_any": n_any, + "n_untyped": n_untyped, + "coverage": coverage, + "strict_coverage": strict_coverage, + } + + +def test_parse_summary_extracts_fields() -> None: + report_json = _make_report(_SAMPLE_SUMMARY) + + result = parse_summary(report_json) + + assert result["n_modules"] == 100 + assert result["n_typable"] == 1000 + assert result["n_typed"] == 400 + assert result["n_any"] == 50 + assert result["n_untyped"] == 550 + assert result["coverage"] == 45.0 + assert result["strict_coverage"] == 40.0 + + +def test_parse_summary_handles_empty_input() -> None: + assert parse_summary("")["n_modules"] == 0 + assert parse_summary(" ")["n_modules"] == 0 + + +def test_parse_summary_handles_invalid_json() -> None: + assert parse_summary("not json")["n_modules"] == 0 + + +def test_parse_summary_handles_missing_summary_key() -> None: + assert parse_summary(json.dumps({"other": 1}))["n_modules"] == 0 + + +def test_parse_summary_handles_incomplete_summary() -> None: + partial = json.dumps({"summary": {"n_modules": 5}}) + assert parse_summary(partial)["n_modules"] == 0 + + +def test_format_summary_markdown_contains_key_metrics() -> None: + summary = _make_summary() + + result = format_summary_markdown(summary) + + assert "**Type coverage**" in result + assert "45.00%" in result + assert "40.00%" in result + assert "| Modules | 100 |" in result + + +def test_format_comparison_markdown_shows_positive_delta() -> None: + base = _make_summary() + pr = _make_summary( + n_modules=101, + n_typable=1010, + n_typed=420, + n_untyped=540, + coverage=46.53, + strict_coverage=41.58, + ) + + result = format_comparison_markdown(base, pr) + + assert "| Base | PR | Delta |" in result + assert "+1.53%" in result + assert "+1.58%" in result + assert "+20" in result + + +def test_format_comparison_markdown_shows_negative_delta() -> None: + base = _make_summary() + pr = _make_summary( + n_typed=390, + n_any=60, + coverage=44.0, + strict_coverage=39.0, + ) + + result = format_comparison_markdown(base, pr) + + assert "-1.00%" in result + assert "-10" in result + + +def test_format_comparison_markdown_shows_zero_delta() -> None: + summary = _make_summary() + + result = format_comparison_markdown(summary, summary) + + assert "0.00%" in result + assert "| 0 |" in result diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py index 5848603ab8..dd41c0c97e 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -396,10 +396,11 @@ class TestExternalDatasetServiceUsageAndBindings: mock_db_session.scalar.return_value = 3 - in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1") + in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1") assert in_use is True assert count == 3 + assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0]) def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock): """ @@ -408,7 +409,7 @@ class TestExternalDatasetServiceUsageAndBindings: mock_db_session.scalar.return_value = 0 - in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1") + in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1") assert in_use is False assert count == 0 diff --git a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py index 20f132c015..53a9e6210c 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py @@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service" def _patched_session(): - """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" + """Patch session_factory.create_session() to return a mock session as context manager.""" session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) - mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) - db_patcher = patch(f"{MODULE}.db") - return patcher, db_patcher, session + session.__enter__ = MagicMock(return_value=session) + session.__exit__ = MagicMock(return_value=False) + session.begin.return_value.__enter__ = MagicMock(return_value=session) + session.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_factory = MagicMock() + mock_factory.create_session.return_value = session + patcher = patch(f"{MODULE}.session_factory", mock_factory) + return patcher, session class TestGetPermission: def test_returns_permission_when_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() permission = MagicMock() session.scalar.return_value = permission - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.get_permission("t1") @@ -30,10 +32,10 @@ class TestGetPermission: assert result is permission def test_returns_none_when_not_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.get_permission("t1") @@ -43,10 +45,10 @@ class TestGetPermission: class TestChangePermission: def test_creates_new_permission_when_not_exists(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: perm_cls.return_value = MagicMock() from services.plugin.plugin_permission_service import PluginPermissionService @@ -54,20 +56,24 @@ class TestChangePermission: "t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE ) + assert result is True + session.begin.assert_called_once() session.add.assert_called_once() def test_updates_existing_permission(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() session.scalar.return_value = existing - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.change_permission( "t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS ) + assert result is True + session.begin.assert_called_once() assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS session.add.assert_not_called() diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index d15074e7a6..eeb5d178ec 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1427,16 +1427,7 @@ class TestRegisterService: mock_tenant.name = "Test Workspace" mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") - # Mock database queries - need to mock the sessionmaker query - mock_session = MagicMock() - mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = None @@ -1475,7 +1466,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session) + mock_lookup.assert_called_once_with("newuser@example.com") def test_invite_new_member_normalizes_new_account_email( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies @@ -1486,13 +1477,7 @@ class TestRegisterService: mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") mixed_email = "Invitee@Example.com" - mock_session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = None @@ -1525,7 +1510,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with(mixed_email, session=mock_session) + mock_lookup.assert_called_once_with(mixed_email) mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) @@ -1545,16 +1530,7 @@ class TestRegisterService: account_id="existing-user-456", email="existing@example.com", status="pending" ) - # Mock database queries - need to mock the sessionmaker query - mock_session = MagicMock() - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = mock_existing_account @@ -1584,7 +1560,7 @@ class TestRegisterService: mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) mock_task_dependencies.delay.assert_called_once() - mock_lookup.assert_called_once_with("existing@example.com", session=mock_session) + mock_lookup.assert_called_once_with("existing@example.com") def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): """Test inviting a member who is already in the tenant.""" diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py index e5a2541da7..9b4734b7ad 100644 --- a/api/tests/unit_tests/services/test_dataset_service_document.py +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -1069,6 +1069,33 @@ class TestDocumentServiceCreateValidation: assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1 assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False + def test_process_rule_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="hierarchical", + rules=Rule( + pre_processing_rules=[ + PreProcessingRule(id="remove_extra_spaces", enabled=True), + ], + segmentation=Segmentation(separator="\n", max_tokens=1024), + subchunk_segmentation=Segmentation(separator="\n", max_tokens=512), + ), + ), + ) + + DocumentService.process_rule_args_validate(knowledge_config) + + assert knowledge_config.process_rule is not None + assert knowledge_config.process_rule.rules is not None + assert knowledge_config.process_rule.rules.parent_mode == "paragraph" + class TestDocumentServiceSaveDocumentWithDatasetId: """Unit tests for non-SQL validation branches in save_document_with_dataset_id.""" diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index 7c8dab5029..0777e6a8a4 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -974,26 +974,29 @@ class TestExternalDatasetServiceAPIUseCheck: """Test API use check when API has one binding.""" # Arrange api_id = "api-123" + tenant_id = "tenant-123" mock_db.session.scalar.return_value = 1 # Act - in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) + in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id) # Assert assert in_use is True assert count == 1 + assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0]) @patch("services.external_knowledge_service.db") def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory): """Test API use check with multiple bindings.""" # Arrange api_id = "api-123" + tenant_id = "tenant-123" mock_db.session.scalar.return_value = 10 # Act - in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) + in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id) # Assert assert in_use is True @@ -1004,11 +1007,12 @@ class TestExternalDatasetServiceAPIUseCheck: """Test API use check when API is not in use.""" # Arrange api_id = "api-123" + tenant_id = "tenant-123" mock_db.session.scalar.return_value = 0 # Act - in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) + in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id) # Assert assert in_use is False diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py deleted file mode 100644 index 7067e3b3dd..0000000000 --- a/api/tests/unit_tests/services/test_ops_service.py +++ /dev/null @@ -1,392 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.ops.entities.config_entity import TracingProviderEnum -from models.model import App, TraceAppConfig -from services.ops_service import OpsService - - -class TestOpsService: - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): - # Arrange - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.get_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = None - - # Act - result = OpsService.get_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = None - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - # Act & Assert - with pytest.raises(ValueError, match="Tracing config cannot be None."): - OpsService.get_tracing_app_config("app_id", "arize") - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - ("provider", "default_url"), - [ - ("arize", "https://app.arize.com/"), - ("phoenix", "https://app.phoenix.arize.com/projects/"), - ("langsmith", "https://smith.langchain.com/"), - ("opik", "https://www.comet.com/opik/"), - ("weave", "https://wandb.ai/"), - ("aliyun", "https://arms.console.aliyun.com/"), - ("tencent", "https://console.cloud.tencent.com/apm"), - ("mlflow", "http://localhost:5000/"), - ("databricks", "https://www.databricks.com/"), - ], - ) - def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} - mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - - # Act - result = OpsService.get_tracing_app_config("app_id", provider) - - # Assert - assert result["tracing_config"]["project_url"] == default_url - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"] - ) - def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} - mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url" - - # Act - result = OpsService.get_tracing_app_config("app_id", provider) - - # Assert - assert result["tracing_config"]["project_url"] == "success_url" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" - - # Act - result = OpsService.get_tracing_app_config("app_id", "langfuse") - - # Assert - assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - - # Act - result = OpsService.get_tracing_app_config("app_id", "langfuse") - - # Assert - assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): - # Act - result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {}) - - # Assert - assert result == {"error": "Invalid tracing provider: invalid_provider"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.LANGFUSE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = False - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"}) - - # Assert - assert result == {"error": "Invalid Credentials"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - ("provider", "config"), - [ - (TracingProviderEnum.ARIZE, {}), - (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), - (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), - (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), - ], - ) - def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config): - # Arrange - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, config) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.LANGFUSE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {} - - # Act - result = OpsService.create_tracing_app_config( - "app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"} - ) - - # Assert - assert result == {"result": "success"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = None - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {} - - # Act - # 'project' is in other_keys for Arize - # provide an empty string for the project in the tracing_config - # create_tracing_app_config will replace it with the default from the model - result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""}) - - # Assert - assert result == {"result": "success"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result == {"result": "success"} - mock_db.session.add.assert_called() - mock_db.session.commit.assert_called() - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): - # Act & Assert - with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): - OpsService.update_tracing_app_config("app_id", "invalid_provider", {}) - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = None - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = app - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.check_trace_config_is_effective.return_value = False - - # Act & Assert - with pytest.raises(ValueError, match="Invalid Credentials"): - OpsService.update_tracing_app_config("app_id", provider, {}) - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - current_config.to_dict.return_value = {"some": "data"} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = app - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result == {"some": "data"} - mock_db.session.commit.assert_called_once() - - @patch("services.ops_service.db") - def test_delete_tracing_app_config_no_config(self, mock_db): - # Arrange - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.delete_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - def test_delete_tracing_app_config_success(self, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = trace_config - - # Act - result = OpsService.delete_tracing_app_config("app_id", "arize") - - # Assert - assert result is True - mock_db.session.delete.assert_called_with(trace_config) - mock_db.session.commit.assert_called_once() diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index ccd2dd53cc..d462ca6449 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -2,7 +2,7 @@ import type { Area } from 'react-easy-crop' import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput' -import type { AvatarProps } from '@/app/components/base/avatar' +import type { AvatarProps } from '@/app/components/base/ui/avatar' import type { ImageFile } from '@/types/app' import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react' import * as React from 'react' @@ -10,10 +10,10 @@ import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import ImageInput from '@/app/components/base/app-icon-picker/ImageInput' import getCroppedImg from '@/app/components/base/app-icon-picker/utils' -import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' +import { Avatar } from '@/app/components/base/ui/avatar' import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' import { toast } from '@/app/components/base/ui/toast' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index 36a510cf63..b81c96df74 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -6,9 +6,9 @@ import { import { Fragment } from 'react' import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' -import { Avatar } from '@/app/components/base/avatar' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import PremiumBadge from '@/app/components/base/premium-badge' +import { Avatar } from '@/app/components/base/ui/avatar' import { useProviderContext } from '@/context/provider-context' import { useRouter } from '@/next/navigation' import { useLogout, useUserProfile } from '@/service/use-common' diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 2c849fd542..b0b7f557a4 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -10,9 +10,9 @@ import { import * as React from 'react' import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' -import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' +import { Avatar } from '@/app/components/base/ui/avatar' import { toast } from '@/app/components/base/ui/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index d7e48f2d1f..b25fb94191 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -5,12 +5,12 @@ import { RiAddCircleFill, RiArrowRightSLine, RiOrganizationChart } from '@remixi import { useDebounce } from 'ahooks' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { Avatar } from '@/app/components/base/ui/avatar' import { useSelector } from '@/context/app-context' import { SubjectType } from '@/models/access-control' import { useSearchForWhiteListCandidates } from '@/service/access-control' import { cn } from '@/utils/classnames' import useAccessControlStore from '../../../../context/access-control-store' -import { Avatar } from '../../base/avatar' import Button from '../../base/button' import Checkbox from '../../base/checkbox' import Input from '../../base/input' diff --git a/web/app/components/app/app-access-control/specific-groups-or-members.tsx b/web/app/components/app/app-access-control/specific-groups-or-members.tsx index 2c0e4b2694..8f4e71c8d2 100644 --- a/web/app/components/app/app-access-control/specific-groups-or-members.tsx +++ b/web/app/components/app/app-access-control/specific-groups-or-members.tsx @@ -3,10 +3,10 @@ import type { AccessControlAccount, AccessControlGroup } from '@/models/access-c import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react' import { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' +import { Avatar } from '@/app/components/base/ui/avatar' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects } from '@/service/access-control' import useAccessControlStore from '../../../../context/access-control-store' -import { Avatar } from '../../base/avatar' import Loading from '../../base/loading' import Tooltip from '../../base/tooltip' import AddMemberOrGroupDialog from './add-member-or-group-pop' diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/chat-item.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/chat-item.spec.tsx index 80bb26a052..61eb8f2ae8 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/chat-item.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/chat-item.spec.tsx @@ -90,7 +90,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({ }, })) -vi.mock('@/app/components/base/avatar', () => ({ +vi.mock('@/app/components/base/ui/avatar', () => ({ Avatar: ({ name }: { name: string }) =>
{name}
, })) diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx index e957fc24c4..56345890ff 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx @@ -7,11 +7,11 @@ import { useCallback, useMemo, } from 'react' -import { Avatar } from '@/app/components/base/avatar' import Chat from '@/app/components/base/chat/chat' import { useChat } from '@/app/components/base/chat/chat/hooks' import { getLastAnswer } from '@/app/components/base/chat/utils' import { useFeatures } from '@/app/components/base/features/hooks' +import { Avatar } from '@/app/components/base/ui/avatar' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useAppContext } from '@/context/app-context' import { useDebugConfigurationContext } from '@/context/debug-configuration' diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx index 84ff8b5ede..a9f9f1116b 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx @@ -3,11 +3,11 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty import type { FileEntity } from '@/app/components/base/file-uploader/types' import { memo, useCallback, useImperativeHandle, useMemo } from 'react' import { useStore as useAppStore } from '@/app/components/app/store' -import { Avatar } from '@/app/components/base/avatar' import Chat from '@/app/components/base/chat/chat' import { useChat } from '@/app/components/base/chat/chat/hooks' import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils' import { useFeatures } from '@/app/components/base/features/hooks' +import { Avatar } from '@/app/components/base/ui/avatar' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useAppContext } from '@/context/app-context' import { useDebugConfigurationContext } from '@/context/debug-configuration' diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 720355d09f..9c3a7cc8f7 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -11,6 +11,7 @@ import AppIcon from '@/app/components/base/app-icon' import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form' import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions' import { Markdown } from '@/app/components/base/markdown' +import { Avatar } from '@/app/components/base/ui/avatar' import { InputVarType } from '@/app/components/workflow/types' import { AppSourceType, @@ -23,7 +24,6 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' import { formatBooleanInputs } from '@/utils/model-config' -import { Avatar } from '../../avatar' import Chat from '../chat' import { useChat } from '../chat/hooks' import { getLastAnswer, isValidGeneratedAnswer } from '../utils' diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index c518a9d078..451f566505 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form' import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar' import { Markdown } from '@/app/components/base/markdown' +import { Avatar } from '@/app/components/base/ui/avatar' import { InputVarType } from '@/app/components/workflow/types' import { AppSourceType, @@ -23,7 +24,6 @@ import { import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow' import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' -import { Avatar } from '../../avatar' import Chat from '../chat' import { useChat } from '../chat/hooks' import { getLastAnswer, isValidGeneratedAnswer } from '../utils' diff --git a/web/app/components/base/avatar/__tests__/index.spec.tsx b/web/app/components/base/ui/avatar/__tests__/index.spec.tsx similarity index 99% rename from web/app/components/base/avatar/__tests__/index.spec.tsx rename to web/app/components/base/ui/avatar/__tests__/index.spec.tsx index 69c56ac993..8be3f8bf0f 100644 --- a/web/app/components/base/avatar/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/avatar/__tests__/index.spec.tsx @@ -1,5 +1,5 @@ import { render, screen } from '@testing-library/react' -import { Avatar } from '../index' +import { Avatar } from '..' describe('Avatar', () => { describe('Rendering', () => { diff --git a/web/app/components/base/avatar/index.stories.tsx b/web/app/components/base/ui/avatar/index.stories.tsx similarity index 100% rename from web/app/components/base/avatar/index.stories.tsx rename to web/app/components/base/ui/avatar/index.stories.tsx diff --git a/web/app/components/base/avatar/index.tsx b/web/app/components/base/ui/avatar/index.tsx similarity index 93% rename from web/app/components/base/avatar/index.tsx rename to web/app/components/base/ui/avatar/index.tsx index dac8988e27..bc00592db7 100644 --- a/web/app/components/base/avatar/index.tsx +++ b/web/app/components/base/ui/avatar/index.tsx @@ -53,8 +53,8 @@ function AvatarRoot({ return ( ) diff --git a/web/app/components/datasets/settings/permission-selector/index.tsx b/web/app/components/datasets/settings/permission-selector/index.tsx index a83beffbb4..cdf13e4b32 100644 --- a/web/app/components/datasets/settings/permission-selector/index.tsx +++ b/web/app/components/datasets/settings/permission-selector/index.tsx @@ -4,13 +4,13 @@ import { useDebounceFn } from 'ahooks' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { Avatar } from '@/app/components/base/avatar' import Input from '@/app/components/base/input' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' +import { Avatar } from '@/app/components/base/ui/avatar' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { DatasetPermission } from '@/models/datasets' import { cn } from '@/utils/classnames' diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index f5b0352a40..442554615b 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -4,9 +4,9 @@ import type { MouseEventHandler, ReactNode } from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' -import { Avatar } from '@/app/components/base/avatar' import PremiumBadge from '@/app/components/base/premium-badge' import ThemeSwitcher from '@/app/components/base/theme-switcher' +import { Avatar } from '@/app/components/base/ui/avatar' import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { IS_CLOUD_EDITION } from '@/config' diff --git a/web/app/components/header/account-setting/members-page/index.tsx b/web/app/components/header/account-setting/members-page/index.tsx index 875ffba3e0..6ac9ee5d2d 100644 --- a/web/app/components/header/account-setting/members-page/index.tsx +++ b/web/app/components/header/account-setting/members-page/index.tsx @@ -2,7 +2,7 @@ import type { InvitationResult } from '@/models/common' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { Avatar } from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/ui/avatar' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' import { NUM_INFINITE } from '@/app/components/billing/config' import { Plan } from '@/app/components/billing/type' diff --git a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx index a6617ac4d2..59e69b92e2 100644 --- a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx +++ b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx @@ -3,9 +3,9 @@ import type { FC } from 'react' import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { Avatar } from '@/app/components/base/avatar' import Input from '@/app/components/base/input' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import { Avatar } from '@/app/components/base/ui/avatar' import { useMembers } from '@/service/use-common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/workflow/__tests__/block-icon.spec.tsx b/web/app/components/workflow/__tests__/block-icon.spec.tsx new file mode 100644 index 0000000000..c3b30a67b6 --- /dev/null +++ b/web/app/components/workflow/__tests__/block-icon.spec.tsx @@ -0,0 +1,46 @@ +import { render } from '@testing-library/react' +import { API_PREFIX } from '@/config' +import BlockIcon, { VarBlockIcon } from '../block-icon' +import { BlockEnum } from '../types' + +describe('BlockIcon', () => { + it('renders the default workflow icon container for regular nodes', () => { + const { container } = render() + + const iconContainer = container.firstElementChild + expect(iconContainer).toHaveClass('w-4', 'h-4', 'bg-util-colors-blue-brand-blue-brand-500', 'extra-class') + expect(iconContainer?.querySelector('svg')).toBeInTheDocument() + }) + + it('normalizes protected plugin icon urls for tool-like nodes', () => { + const { container } = render( + , + ) + + const iconContainer = container.firstElementChild as HTMLElement + const backgroundIcon = iconContainer.querySelector('div') as HTMLElement + + expect(iconContainer).not.toHaveClass('bg-util-colors-blue-blue-500') + expect(backgroundIcon.style.backgroundImage).toContain( + `${API_PREFIX}/workspaces/current/plugin/icon/plugin-tool.png`, + ) + }) +}) + +describe('VarBlockIcon', () => { + it('renders the compact icon variant without the default container wrapper', () => { + const { container } = render( + , + ) + + expect(container.querySelector('.custom-var-icon')).toBeInTheDocument() + expect(container.querySelector('svg')).toBeInTheDocument() + expect(container.querySelector('.bg-util-colors-warning-warning-500')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/__tests__/context.spec.tsx b/web/app/components/workflow/__tests__/context.spec.tsx new file mode 100644 index 0000000000..ccf1eaa9b1 --- /dev/null +++ b/web/app/components/workflow/__tests__/context.spec.tsx @@ -0,0 +1,39 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { WorkflowContextProvider } from '../context' +import { useStore, useWorkflowStore } from '../store' + +const StoreConsumer = () => { + const showSingleRunPanel = useStore(s => s.showSingleRunPanel) + const store = useWorkflowStore() + + return ( + + ) +} + +describe('WorkflowContextProvider', () => { + it('provides the workflow store to descendants and keeps the same store across rerenders', async () => { + const user = userEvent.setup() + const { rerender } = render( + + + , + ) + + expect(screen.getByRole('button', { name: 'closed' })).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'closed' })) + expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument() + + rerender( + + + , + ) + + expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/__tests__/index.spec.tsx b/web/app/components/workflow/__tests__/index.spec.tsx new file mode 100644 index 0000000000..77b61e54e7 --- /dev/null +++ b/web/app/components/workflow/__tests__/index.spec.tsx @@ -0,0 +1,67 @@ +import type { Edge, Node } from '../types' +import { render, screen } from '@testing-library/react' +import { useStoreApi } from 'reactflow' +import { useDatasetsDetailStore } from '../datasets-detail-store/store' +import WorkflowWithDefaultContext from '../index' +import { BlockEnum } from '../types' +import { useWorkflowHistoryStore } from '../workflow-history-store' + +const nodes: Node[] = [ + { + id: 'node-start', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + title: 'Start', + desc: '', + type: BlockEnum.Start, + }, + }, +] + +const edges: Edge[] = [ + { + id: 'edge-1', + source: 'node-start', + target: 'node-end', + sourceHandle: null, + targetHandle: null, + type: 'custom', + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, + }, +] + +const ContextConsumer = () => { + const { store, shortcutsEnabled } = useWorkflowHistoryStore() + const datasetCount = useDatasetsDetailStore(state => Object.keys(state.datasetsDetail).length) + const reactFlowStore = useStoreApi() + + return ( +
+ {`history:${store.getState().nodes.length}`} + {` shortcuts:${String(shortcutsEnabled)}`} + {` datasets:${datasetCount}`} + {` reactflow:${String(!!reactFlowStore)}`} +
+ ) +} + +describe('WorkflowWithDefaultContext', () => { + it('wires the ReactFlow, workflow history, and datasets detail providers around its children', () => { + render( + + + , + ) + + expect( + screen.getByText('history:1 shortcuts:true datasets:0 reactflow:true'), + ).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/__tests__/shortcuts-name.spec.tsx b/web/app/components/workflow/__tests__/shortcuts-name.spec.tsx new file mode 100644 index 0000000000..87efddb005 --- /dev/null +++ b/web/app/components/workflow/__tests__/shortcuts-name.spec.tsx @@ -0,0 +1,51 @@ +import { render, screen } from '@testing-library/react' +import ShortcutsName from '../shortcuts-name' + +describe('ShortcutsName', () => { + const originalNavigator = globalThis.navigator + + afterEach(() => { + Object.defineProperty(globalThis, 'navigator', { + value: originalNavigator, + writable: true, + configurable: true, + }) + }) + + it('renders mac-friendly key labels and style variants', () => { + Object.defineProperty(globalThis, 'navigator', { + value: { userAgent: 'Macintosh' }, + writable: true, + configurable: true, + }) + + const { container } = render( + , + ) + + expect(screen.getByText('⌘')).toBeInTheDocument() + expect(screen.getByText('⇧')).toBeInTheDocument() + expect(screen.getByText('s')).toBeInTheDocument() + expect(container.querySelector('.system-kbd')).toHaveClass( + 'bg-components-kbd-bg-white', + 'text-text-tertiary', + ) + }) + + it('keeps raw key names on non-mac systems', () => { + Object.defineProperty(globalThis, 'navigator', { + value: { userAgent: 'Windows NT' }, + writable: true, + configurable: true, + }) + + render() + + expect(screen.getByText('ctrl')).toBeInTheDocument() + expect(screen.getByText('alt')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/__tests__/workflow-history-store.spec.tsx b/web/app/components/workflow/__tests__/workflow-history-store.spec.tsx new file mode 100644 index 0000000000..931cd97c02 --- /dev/null +++ b/web/app/components/workflow/__tests__/workflow-history-store.spec.tsx @@ -0,0 +1,97 @@ +import type { Edge, Node } from '../types' +import type { WorkflowHistoryState } from '../workflow-history-store' +import { render, renderHook, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum } from '../types' +import { useWorkflowHistoryStore, WorkflowHistoryProvider } from '../workflow-history-store' + +const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + title: 'Start', + desc: '', + type: BlockEnum.Start, + selected: true, + }, + selected: true, + }, +] + +const edges: Edge[] = [ + { + id: 'edge-1', + source: 'node-1', + target: 'node-2', + sourceHandle: null, + targetHandle: null, + type: 'custom', + selected: true, + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, + }, +] + +const HistoryConsumer = () => { + const { store, shortcutsEnabled, setShortcutsEnabled } = useWorkflowHistoryStore() + + return ( + + ) +} + +describe('WorkflowHistoryProvider', () => { + it('provides workflow history state and shortcut toggles', async () => { + const user = userEvent.setup() + + render( + + + , + ) + + expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })) + expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:false' })).toBeInTheDocument() + }) + + it('sanitizes selected flags when history state is replaced through the exposed store api', () => { + const wrapper = ({ children }: { children: React.ReactNode }) => ( + + {children} + + ) + + const { result } = renderHook(() => useWorkflowHistoryStore(), { wrapper }) + const nextState: WorkflowHistoryState = { + workflowHistoryEvent: undefined, + workflowHistoryEventMeta: undefined, + nodes, + edges, + } + + result.current.store.setState(nextState) + + expect(result.current.store.getState().nodes[0].data.selected).toBe(false) + expect(result.current.store.getState().edges[0].selected).toBe(false) + }) + + it('throws when consumed outside the provider', () => { + expect(() => renderHook(() => useWorkflowHistoryStore())).toThrow( + 'useWorkflowHistoryStoreApi must be used within a WorkflowHistoryProvider', + ) + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/all-tools.spec.tsx b/web/app/components/workflow/block-selector/__tests__/all-tools.spec.tsx new file mode 100644 index 0000000000..64f012fae3 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/all-tools.spec.tsx @@ -0,0 +1,140 @@ +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { useMarketplacePlugins } from '@/app/components/plugins/marketplace/hooks' +import { useGlobalPublicStore } from '@/context/global-public-context' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import AllTools from '../all-tools' +import { createGlobalPublicStoreState, createToolProvider } from './factories' + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: vi.fn(), +})) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/plugins/marketplace/hooks', () => ({ + useMarketplacePlugins: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +vi.mock('@/utils/var', async importOriginal => ({ + ...(await importOriginal()), + getMarketplaceUrl: () => 'https://marketplace.test/tools', +})) + +const mockUseMarketplacePlugins = vi.mocked(useMarketplacePlugins) +const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore) +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +const createMarketplacePluginsMock = () => ({ + plugins: [], + total: 0, + resetPlugins: vi.fn(), + queryPlugins: vi.fn(), + queryPluginsWithDebounced: vi.fn(), + cancelQueryPluginsWithDebounced: vi.fn(), + isLoading: false, + isFetchingNextPage: false, + hasNextPage: false, + fetchNextPage: vi.fn(), + page: 0, +}) + +describe('AllTools', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGlobalPublicStore.mockImplementation(selector => selector(createGlobalPublicStoreState(false))) + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + mockUseMarketplacePlugins.mockReturnValue(createMarketplacePluginsMock()) + }) + + it('filters tools by the active tab', async () => { + const user = userEvent.setup() + + render( + , + ) + + expect(screen.getByText('Built In Provider')).toBeInTheDocument() + expect(screen.getByText('Custom Provider')).toBeInTheDocument() + + await user.click(screen.getByText('workflow.tabs.customTool')) + + expect(screen.getByText('Custom Provider')).toBeInTheDocument() + expect(screen.queryByText('Built In Provider')).not.toBeInTheDocument() + }) + + it('filters the rendered tools by the search text', () => { + render( + , + ) + + expect(screen.getByText('Report Toolkit')).toBeInTheDocument() + expect(screen.queryByText('Other Toolkit')).not.toBeInTheDocument() + }) + + it('shows the empty state when no tool matches the current filter', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('workflow.tabs.noPluginsFound')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/blocks.spec.tsx b/web/app/components/workflow/block-selector/__tests__/blocks.spec.tsx new file mode 100644 index 0000000000..00972f808c --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/blocks.spec.tsx @@ -0,0 +1,79 @@ +import type { NodeDefault } from '../../types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum } from '../../types' +import Blocks from '../blocks' +import { BlockClassificationEnum } from '../types' + +const runtimeState = vi.hoisted(() => ({ + nodes: [] as Array<{ data: { type?: BlockEnum } }>, +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => runtimeState.nodes, + }), + }), +})) + +const createBlock = (type: BlockEnum, title: string, classification = BlockClassificationEnum.Default): NodeDefault => ({ + metaData: { + classification, + sort: 0, + type, + title, + author: 'Dify', + description: `${title} description`, + }, + defaultValue: {}, + checkValid: () => ({ isValid: true }), +}) + +describe('Blocks', () => { + beforeEach(() => { + runtimeState.nodes = [] + }) + + it('renders grouped blocks, filters duplicate knowledge-base nodes, and selects a block', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + runtimeState.nodes = [{ data: { type: BlockEnum.KnowledgeBase } }] + + render( + , + ) + + expect(screen.getByText('LLM')).toBeInTheDocument() + expect(screen.getByText('Exit Loop')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.loop.loopNode')).toBeInTheDocument() + expect(screen.queryByText('Knowledge Retrieval')).not.toBeInTheDocument() + + await user.click(screen.getByText('LLM')) + + expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM) + }) + + it('shows the empty state when no block matches the search', () => { + render( + , + ) + + expect(screen.getByText('workflow.tabs.noResult')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/factories.ts b/web/app/components/workflow/block-selector/__tests__/factories.ts new file mode 100644 index 0000000000..b7d82f7cb3 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/factories.ts @@ -0,0 +1,101 @@ +import type { ToolWithProvider } from '../../types' +import type { Plugin } from '@/app/components/plugins/types' +import type { Tool } from '@/app/components/tools/types' +import { PluginCategoryEnum } from '@/app/components/plugins/types' +import { CollectionType } from '@/app/components/tools/types' +import { defaultSystemFeatures } from '@/types/feature' + +export const createTool = ( + name: string, + label: string, + description = `${label} description`, +): Tool => ({ + name, + author: 'author', + label: { + en_US: label, + zh_Hans: label, + }, + description: { + en_US: description, + zh_Hans: description, + }, + parameters: [], + labels: [], + output_schema: {}, +}) + +export const createToolProvider = ( + overrides: Partial = {}, +): ToolWithProvider => ({ + id: 'provider-1', + name: 'provider-one', + author: 'Provider Author', + description: { + en_US: 'Provider description', + zh_Hans: 'Provider description', + }, + icon: 'icon', + icon_dark: 'icon-dark', + label: { + en_US: 'Provider One', + zh_Hans: 'Provider One', + }, + type: CollectionType.builtIn, + team_credentials: {}, + is_team_authorization: false, + allow_delete: false, + labels: [], + plugin_id: 'plugin-1', + tools: [createTool('tool-a', 'Tool A')], + meta: { version: '1.0.0' } as ToolWithProvider['meta'], + plugin_unique_identifier: 'plugin-1@1.0.0', + ...overrides, +}) + +export const createPlugin = (overrides: Partial = {}): Plugin => ({ + type: 'plugin', + org: 'org', + author: 'author', + name: 'Plugin One', + plugin_id: 'plugin-1', + version: '1.0.0', + latest_version: '1.0.0', + latest_package_identifier: 'plugin-1@1.0.0', + icon: 'icon', + verified: true, + label: { + en_US: 'Plugin One', + zh_Hans: 'Plugin One', + }, + brief: { + en_US: 'Plugin description', + zh_Hans: 'Plugin description', + }, + description: { + en_US: 'Plugin description', + zh_Hans: 'Plugin description', + }, + introduction: 'Plugin introduction', + repository: 'https://example.com/plugin', + category: PluginCategoryEnum.tool, + tags: [], + badges: [], + install_count: 0, + endpoint: { + settings: [], + }, + verification: { + authorized_category: 'community', + }, + from: 'github', + ...overrides, +}) + +export const createGlobalPublicStoreState = (enableMarketplace: boolean) => ({ + systemFeatures: { + ...defaultSystemFeatures, + enable_marketplace: enableMarketplace, + }, + setSystemFeatures: vi.fn(), +}) diff --git a/web/app/components/workflow/block-selector/__tests__/featured-tools.spec.tsx b/web/app/components/workflow/block-selector/__tests__/featured-tools.spec.tsx new file mode 100644 index 0000000000..1720a2d897 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/featured-tools.spec.tsx @@ -0,0 +1,101 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import FeaturedTools from '../featured-tools' +import { createPlugin, createToolProvider } from './factories' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +vi.mock('@/utils/var', async importOriginal => ({ + ...(await importOriginal()), + getMarketplaceUrl: () => 'https://marketplace.test/tools', +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('FeaturedTools', () => { + beforeEach(() => { + vi.clearAllMocks() + localStorage.clear() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('shows more featured tools when the list exceeds the initial quota', async () => { + const user = userEvent.setup() + const plugins = Array.from({ length: 6 }, (_, index) => + createPlugin({ + plugin_id: `plugin-${index + 1}`, + latest_package_identifier: `plugin-${index + 1}@1.0.0`, + label: { en_US: `Plugin ${index + 1}`, zh_Hans: `Plugin ${index + 1}` }, + })) + const providers = plugins.map((plugin, index) => + createToolProvider({ + id: `provider-${index + 1}`, + plugin_id: plugin.plugin_id, + label: { en_US: `Provider ${index + 1}`, zh_Hans: `Provider ${index + 1}` }, + }), + ) + const providerMap = new Map(providers.map(provider => [provider.plugin_id!, provider])) + + render( + , + ) + + expect(screen.getByText('Provider 1')).toBeInTheDocument() + expect(screen.queryByText('Provider 6')).not.toBeInTheDocument() + + await user.click(screen.getByText('workflow.tabs.showMoreFeatured')) + + expect(screen.getByText('Provider 6')).toBeInTheDocument() + }) + + it('honors the persisted collapsed state', () => { + localStorage.setItem('workflow_tools_featured_collapsed', 'true') + + render( + , + ) + + expect(screen.getByText('workflow.tabs.featuredTools')).toBeInTheDocument() + expect(screen.queryByText('Provider One')).not.toBeInTheDocument() + }) + + it('shows the marketplace empty state when no featured tools are available', () => { + render( + , + ) + + expect(screen.getByText('workflow.tabs.noFeaturedPlugins')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/hooks.spec.tsx b/web/app/components/workflow/block-selector/__tests__/hooks.spec.tsx new file mode 100644 index 0000000000..6d27560802 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/hooks.spec.tsx @@ -0,0 +1,52 @@ +import { act, renderHook } from '@testing-library/react' +import { useTabs, useToolTabs } from '../hooks' +import { TabsEnum, ToolTypeEnum } from '../types' + +describe('block-selector hooks', () => { + it('falls back to the first valid tab when the preferred start tab is disabled', () => { + const { result } = renderHook(() => useTabs({ + noStart: false, + hasUserInputNode: true, + defaultActiveTab: TabsEnum.Start, + })) + + expect(result.current.tabs.find(tab => tab.key === TabsEnum.Start)?.disabled).toBe(true) + expect(result.current.activeTab).toBe(TabsEnum.Blocks) + }) + + it('keeps the start tab enabled when forcing it on and resets to a valid tab after disabling blocks', () => { + const props: Parameters[0] = { + noBlocks: false, + noStart: false, + hasUserInputNode: true, + forceEnableStartTab: true, + } + + const { result, rerender } = renderHook(nextProps => useTabs(nextProps), { + initialProps: props, + }) + + expect(result.current.tabs.find(tab => tab.key === TabsEnum.Start)?.disabled).toBeFalsy() + + act(() => { + result.current.setActiveTab(TabsEnum.Blocks) + }) + + rerender({ + ...props, + noBlocks: true, + noSources: true, + noTools: true, + }) + + expect(result.current.activeTab).toBe(TabsEnum.Start) + }) + + it('returns the MCP tab only when it is not hidden', () => { + const { result: visible } = renderHook(() => useToolTabs()) + const { result: hidden } = renderHook(() => useToolTabs(true)) + + expect(visible.current.some(tab => tab.key === ToolTypeEnum.MCP)).toBe(true) + expect(hidden.current.some(tab => tab.key === ToolTypeEnum.MCP)).toBe(false) + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/index.spec.tsx b/web/app/components/workflow/block-selector/__tests__/index.spec.tsx new file mode 100644 index 0000000000..735a831c10 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/index.spec.tsx @@ -0,0 +1,90 @@ +import type { NodeDefault, ToolWithProvider } from '../../types' +import { screen } from '@testing-library/react' +import { renderWorkflowComponent } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import NodeSelectorWrapper from '../index' +import { BlockClassificationEnum } from '../types' + +vi.mock('reactflow', async () => + (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock()) + +vi.mock('@/service/use-plugins', () => ({ + useFeaturedToolsRecommendations: () => ({ + plugins: [], + isLoading: false, + }), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: [] }), + useAllCustomTools: () => ({ data: [] }), + useAllWorkflowTools: () => ({ data: [] }), + useAllMCPTools: () => ({ data: [] }), + useInvalidateAllBuiltInTools: () => vi.fn(), +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector({ + systemFeatures: { enable_marketplace: false }, + }), +})) + +const createBlock = (type: BlockEnum, title: string): NodeDefault => ({ + metaData: { + type, + title, + sort: 0, + classification: BlockClassificationEnum.Default, + author: 'Dify', + description: `${title} description`, + }, + defaultValue: {}, + checkValid: () => ({ isValid: true }), +}) + +const dataSource: ToolWithProvider = { + id: 'datasource-1', + name: 'datasource', + author: 'Dify', + description: { en_US: 'Data source', zh_Hans: '数据源' }, + icon: 'icon', + label: { en_US: 'Data Source', zh_Hans: 'Data Source' }, + type: 'datasource' as ToolWithProvider['type'], + team_credentials: {}, + is_team_authorization: false, + allow_delete: false, + labels: [], + tools: [], + meta: { version: '1.0.0' } as ToolWithProvider['meta'], +} + +describe('NodeSelectorWrapper', () => { + it('filters hidden block types from hooks store and forwards data sources', async () => { + renderWorkflowComponent( + , + { + hooksStoreProps: { + availableNodesMetaData: { + nodes: [ + createBlock(BlockEnum.Start, 'Start'), + createBlock(BlockEnum.Tool, 'Tool'), + createBlock(BlockEnum.Code, 'Code'), + createBlock(BlockEnum.DataSource, 'Data Source'), + ], + }, + }, + initialStoreState: { + dataSourceList: [dataSource], + }, + }, + ) + + expect(await screen.findByText('Code')).toBeInTheDocument() + expect(screen.queryByText('Start')).not.toBeInTheDocument() + expect(screen.queryByText('Tool')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/main.spec.tsx b/web/app/components/workflow/block-selector/__tests__/main.spec.tsx new file mode 100644 index 0000000000..1deb6ce84c --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/main.spec.tsx @@ -0,0 +1,95 @@ +import type { NodeDefault } from '../../types' +import { screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { renderWorkflowComponent } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import NodeSelector from '../main' +import { BlockClassificationEnum } from '../types' + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => [], + }), + }), +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector({ + systemFeatures: { enable_marketplace: false }, + }), +})) + +vi.mock('@/service/use-plugins', () => ({ + useFeaturedToolsRecommendations: () => ({ + plugins: [], + isLoading: false, + }), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: [] }), + useAllCustomTools: () => ({ data: [] }), + useAllWorkflowTools: () => ({ data: [] }), + useAllMCPTools: () => ({ data: [] }), + useInvalidateAllBuiltInTools: () => vi.fn(), +})) + +const createBlock = (type: BlockEnum, title: string): NodeDefault => ({ + metaData: { + classification: BlockClassificationEnum.Default, + sort: 0, + type, + title, + author: 'Dify', + description: `${title} description`, + }, + defaultValue: {}, + checkValid: () => ({ isValid: true }), +}) + +describe('NodeSelector', () => { + it('opens with the real blocks tab, filters by search, selects a block, and clears search after close', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + renderWorkflowComponent( + ( + + )} + />, + ) + + await user.click(screen.getByRole('button', { name: 'selector-closed' })) + + const searchInput = screen.getByPlaceholderText('workflow.tabs.searchBlock') + expect(screen.getByText('LLM')).toBeInTheDocument() + expect(screen.getByText('End')).toBeInTheDocument() + + await user.type(searchInput, 'LLM') + expect(screen.getByText('LLM')).toBeInTheDocument() + expect(screen.queryByText('End')).not.toBeInTheDocument() + + await user.click(screen.getByText('LLM')) + + expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM, undefined) + await waitFor(() => { + expect(screen.queryByPlaceholderText('workflow.tabs.searchBlock')).not.toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: 'selector-closed' })) + + const reopenedInput = screen.getByPlaceholderText('workflow.tabs.searchBlock') as HTMLInputElement + expect(reopenedInput.value).toBe('') + expect(screen.getByText('End')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/tools.spec.tsx b/web/app/components/workflow/block-selector/__tests__/tools.spec.tsx new file mode 100644 index 0000000000..a800342e6e --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/tools.spec.tsx @@ -0,0 +1,95 @@ +import { render, screen } from '@testing-library/react' +import { CollectionType } from '@/app/components/tools/types' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import Tools from '../tools' +import { ViewType } from '../view-type-select' +import { createToolProvider } from './factories' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('Tools', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('shows the empty state when there are no tools and no search text', () => { + render( + , + ) + + expect(screen.getByText('No tools available')).toBeInTheDocument() + }) + + it('renders tree groups for built-in and custom providers', () => { + render( + , + ) + + expect(screen.getByText('Built In')).toBeInTheDocument() + expect(screen.getByText('workflow.tabs.customTool')).toBeInTheDocument() + expect(screen.getByText('Built In Provider')).toBeInTheDocument() + expect(screen.getByText('Custom Provider')).toBeInTheDocument() + }) + + it('shows the alphabetical index in flat view when enough tools are present', () => { + const { container } = render( + + createToolProvider({ + id: `provider-${index}`, + label: { + en_US: `${String.fromCharCode(65 + index)} Provider`, + zh_Hans: `${String.fromCharCode(65 + index)} Provider`, + }, + }))} + onSelect={vi.fn()} + viewType={ViewType.flat} + hasSearchText={false} + />, + ) + + expect(container.querySelector('.index-bar')).toBeInTheDocument() + expect(screen.getByText('A Provider')).toBeInTheDocument() + expect(screen.getByText('K Provider')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/tool/__tests__/tool.spec.tsx b/web/app/components/workflow/block-selector/tool/__tests__/tool.spec.tsx new file mode 100644 index 0000000000..d9fad38854 --- /dev/null +++ b/web/app/components/workflow/block-selector/tool/__tests__/tool.spec.tsx @@ -0,0 +1,99 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { trackEvent } from '@/app/components/base/amplitude' +import { CollectionType } from '@/app/components/tools/types' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { BlockEnum } from '../../../types' +import { createTool, createToolProvider } from '../../__tests__/factories' +import { ViewType } from '../../view-type-select' +import Tool from '../tool' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) +const mockTrackEvent = vi.mocked(trackEvent) + +describe('Tool', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('expands a provider and selects an action item', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + render( + , + ) + + await user.click(screen.getByText('Provider One')) + await user.click(screen.getByText('Tool B')) + + expect(onSelect).toHaveBeenCalledWith(BlockEnum.Tool, expect.objectContaining({ + provider_id: 'provider-1', + provider_name: 'provider-one', + tool_name: 'tool-b', + title: 'Tool B', + })) + expect(mockTrackEvent).toHaveBeenCalledWith('tool_selected', { + tool_name: 'tool-b', + plugin_id: 'plugin-1', + }) + }) + + it('selects workflow tools directly without expanding the provider', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + render( + , + ) + + await user.click(screen.getByText('Workflow Tool')) + + expect(onSelect).toHaveBeenCalledWith(BlockEnum.Tool, expect.objectContaining({ + provider_type: CollectionType.workflow, + tool_name: 'workflow-tool', + tool_label: 'Workflow Tool', + })) + }) +}) diff --git a/web/app/components/workflow/block-selector/tool/tool-list-flat-view/__tests__/list.spec.tsx b/web/app/components/workflow/block-selector/tool/tool-list-flat-view/__tests__/list.spec.tsx new file mode 100644 index 0000000000..ecb5dfe0a6 --- /dev/null +++ b/web/app/components/workflow/block-selector/tool/tool-list-flat-view/__tests__/list.spec.tsx @@ -0,0 +1,66 @@ +import { render, screen } from '@testing-library/react' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { createToolProvider } from '../../../__tests__/factories' +import List from '../list' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('ToolListFlatView', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('assigns the first tool of each letter to the shared refs and renders the index bar', () => { + const toolRefs = { + current: {} as Record, + } + + render( + ), + createToolProvider({ + id: 'provider-b', + label: { en_US: 'B Provider', zh_Hans: 'B Provider' }, + letter: 'B', + } as ReturnType), + ]} + isShowLetterIndex + indexBar={
} + hasSearchText={false} + onSelect={vi.fn()} + toolRefs={toolRefs} + />, + ) + + expect(screen.getByText('A Provider')).toBeInTheDocument() + expect(screen.getByText('B Provider')).toBeInTheDocument() + expect(screen.getByTestId('index-bar')).toBeInTheDocument() + expect(toolRefs.current.A).toBeTruthy() + expect(toolRefs.current.B).toBeTruthy() + }) +}) diff --git a/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/item.spec.tsx b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/item.spec.tsx new file mode 100644 index 0000000000..027ad7c11c --- /dev/null +++ b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/item.spec.tsx @@ -0,0 +1,47 @@ +import { render, screen } from '@testing-library/react' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { createToolProvider } from '../../../__tests__/factories' +import Item from '../item' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('ToolListTreeView Item', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('renders the group heading and its provider list', () => { + render( + , + ) + + expect(screen.getByText('My Group')).toBeInTheDocument() + expect(screen.getByText('Provider Alpha')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/list.spec.tsx b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/list.spec.tsx new file mode 100644 index 0000000000..7b3c083e85 --- /dev/null +++ b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/list.spec.tsx @@ -0,0 +1,56 @@ +import { render, screen } from '@testing-library/react' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { createToolProvider } from '../../../__tests__/factories' +import { CUSTOM_GROUP_NAME } from '../../../index-bar' +import List from '../list' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('ToolListTreeView', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('translates built-in special group names and renders the nested providers', () => { + render( + , + ) + + expect(screen.getByText('BuiltIn')).toBeInTheDocument() + expect(screen.getByText('workflow.tabs.customTool')).toBeInTheDocument() + expect(screen.getByText('Built In Provider')).toBeInTheDocument() + expect(screen.getByText('Custom Provider')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/datasets-detail-store/__tests__/store.spec.tsx b/web/app/components/workflow/datasets-detail-store/__tests__/store.spec.tsx new file mode 100644 index 0000000000..a031c6370e --- /dev/null +++ b/web/app/components/workflow/datasets-detail-store/__tests__/store.spec.tsx @@ -0,0 +1,91 @@ +import type { DataSet } from '@/models/datasets' +import { renderHook } from '@testing-library/react' +import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' +import { DatasetsDetailContext } from '../provider' +import { createDatasetsDetailStore, useDatasetsDetailStore } from '../store' + +const createDataset = (id: string, name = `dataset-${id}`): DataSet => ({ + id, + name, + indexing_status: 'completed', + icon_info: { + icon: 'book', + icon_type: 'emoji' as DataSet['icon_info']['icon_type'], + }, + description: `${name} description`, + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 1, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + word_count: 0, + provider: 'provider', + embedding_model: 'model', + embedding_model_provider: 'provider', + embedding_available: true, + retrieval_model_dict: {} as DataSet['retrieval_model_dict'], + retrieval_model: {} as DataSet['retrieval_model'], + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 1, + score_threshold: 0, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + runtime_mode: 'general', + enable_api: false, + is_multimodal: false, +}) + +describe('datasets-detail-store store', () => { + it('merges dataset details by id', () => { + const store = createDatasetsDetailStore() + + store.getState().updateDatasetsDetail([ + createDataset('dataset-1', 'Dataset One'), + createDataset('dataset-2', 'Dataset Two'), + ]) + store.getState().updateDatasetsDetail([ + createDataset('dataset-2', 'Dataset Two Updated'), + ]) + + expect(store.getState().datasetsDetail).toMatchObject({ + 'dataset-1': { name: 'Dataset One' }, + 'dataset-2': { name: 'Dataset Two Updated' }, + }) + }) + + it('reads state from the datasets detail context', () => { + const store = createDatasetsDetailStore() + store.getState().updateDatasetsDetail([createDataset('dataset-3')]) + const wrapper = ({ children }: { children: React.ReactNode }) => ( + + {children} + + ) + + const { result } = renderHook( + () => useDatasetsDetailStore(state => state.datasetsDetail['dataset-3']?.name), + { wrapper }, + ) + + expect(result.current).toBe('dataset-dataset-3') + }) + + it('throws when the datasets detail provider is missing', () => { + expect(() => renderHook(() => useDatasetsDetailStore(state => state.datasetsDetail))).toThrow( + 'Missing DatasetsDetailContext.Provider in the tree', + ) + }) +}) diff --git a/web/app/components/workflow/hooks-store/__tests__/store.spec.tsx b/web/app/components/workflow/hooks-store/__tests__/store.spec.tsx new file mode 100644 index 0000000000..131290b834 --- /dev/null +++ b/web/app/components/workflow/hooks-store/__tests__/store.spec.tsx @@ -0,0 +1,41 @@ +import { renderHook } from '@testing-library/react' +import { HooksStoreContext } from '../provider' +import { createHooksStore, useHooksStore } from '../store' + +describe('hooks-store store', () => { + it('creates default callbacks and refreshes selected handlers', () => { + const store = createHooksStore({}) + const handleBackupDraft = vi.fn() + + expect(store.getState().availableNodesMetaData).toEqual({ nodes: [] }) + expect(store.getState().hasNodeInspectVars('node-1')).toBe(false) + expect(store.getState().getWorkflowRunAndTraceUrl('run-1')).toEqual({ + runUrl: '', + traceUrl: '', + }) + + store.getState().refreshAll({ handleBackupDraft }) + + expect(store.getState().handleBackupDraft).toBe(handleBackupDraft) + }) + + it('reads state from the hooks store context', () => { + const handleRun = vi.fn() + const store = createHooksStore({ handleRun }) + const wrapper = ({ children }: { children: React.ReactNode }) => ( + + {children} + + ) + + const { result } = renderHook(() => useHooksStore(state => state.handleRun), { wrapper }) + + expect(result.current).toBe(handleRun) + }) + + it('throws when the hooks store provider is missing', () => { + expect(() => renderHook(() => useHooksStore(state => state.handleRun))).toThrow( + 'Missing HooksStoreContext.Provider in the tree', + ) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-DSL.spec.ts b/web/app/components/workflow/hooks/__tests__/use-DSL.spec.ts new file mode 100644 index 0000000000..f10777ae69 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-DSL.spec.ts @@ -0,0 +1,19 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useDSL } from '../use-DSL' + +describe('useDSL', () => { + it('returns the DSL handlers from hooks store', () => { + const exportCheck = vi.fn() + const handleExportDSL = vi.fn() + + const { result } = renderWorkflowHook(() => useDSL(), { + hooksStoreProps: { + exportCheck, + handleExportDSL, + }, + }) + + expect(result.current.exportCheck).toBe(exportCheck) + expect(result.current.handleExportDSL).toBe(handleExportDSL) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-edges-interactions-without-sync.spec.ts b/web/app/components/workflow/hooks/__tests__/use-edges-interactions-without-sync.spec.ts new file mode 100644 index 0000000000..b38aca6398 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-edges-interactions-without-sync.spec.ts @@ -0,0 +1,90 @@ +import { act, waitFor } from '@testing-library/react' +import { useEdges } from 'reactflow' +import { createEdge, createNode } from '../../__tests__/fixtures' +import { renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../types' +import { useEdgesInteractionsWithoutSync } from '../use-edges-interactions-without-sync' + +type EdgeRuntimeState = { + _sourceRunningStatus?: NodeRunningStatus + _targetRunningStatus?: NodeRunningStatus + _waitingRun?: boolean +} + +const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => + (edge?.data ?? {}) as EdgeRuntimeState + +const createFlowNodes = () => [ + createNode({ id: 'a' }), + createNode({ id: 'b' }), + createNode({ id: 'c' }), +] + +const createFlowEdges = () => [ + createEdge({ + id: 'e1', + source: 'a', + target: 'b', + data: { + _sourceRunningStatus: NodeRunningStatus.Running, + _targetRunningStatus: NodeRunningStatus.Running, + _waitingRun: true, + }, + }), + createEdge({ + id: 'e2', + source: 'b', + target: 'c', + data: { + _sourceRunningStatus: NodeRunningStatus.Succeeded, + _targetRunningStatus: undefined, + _waitingRun: false, + }, + }), +] + +const renderEdgesInteractionsHook = () => + renderWorkflowFlowHook(() => ({ + ...useEdgesInteractionsWithoutSync(), + edges: useEdges(), + }), { + nodes: createFlowNodes(), + edges: createFlowEdges(), + }) + +describe('useEdgesInteractionsWithoutSync', () => { + it('clears running status and waitingRun on all edges', () => { + const { result } = renderEdgesInteractionsHook() + + act(() => { + result.current.handleEdgeCancelRunningStatus() + }) + + return waitFor(() => { + result.current.edges.forEach((edge) => { + const edgeState = getEdgeRuntimeState(edge) + expect(edgeState._sourceRunningStatus).toBeUndefined() + expect(edgeState._targetRunningStatus).toBeUndefined() + expect(edgeState._waitingRun).toBe(false) + }) + }) + }) + + it('does not mutate the original edges array', () => { + const edges = createFlowEdges() + const originalData = { ...getEdgeRuntimeState(edges[0]) } + const { result } = renderWorkflowFlowHook(() => ({ + ...useEdgesInteractionsWithoutSync(), + edges: useEdges(), + }), { + nodes: createFlowNodes(), + edges, + }) + + act(() => { + result.current.handleEdgeCancelRunningStatus() + }) + + expect(getEdgeRuntimeState(edges[0])._sourceRunningStatus).toBe(originalData._sourceRunningStatus) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-edges-interactions.helpers.spec.ts b/web/app/components/workflow/hooks/__tests__/use-edges-interactions.helpers.spec.ts new file mode 100644 index 0000000000..3741bcc653 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-edges-interactions.helpers.spec.ts @@ -0,0 +1,114 @@ +import { createEdge, createNode } from '../../__tests__/fixtures' +import { getNodesConnectedSourceOrTargetHandleIdsMap } from '../../utils' +import { + applyConnectedHandleNodeData, + buildContextMenuEdges, + clearEdgeMenuIfNeeded, + clearNodeSelectionState, + updateEdgeHoverState, + updateEdgeSelectionState, +} from '../use-edges-interactions.helpers' + +vi.mock('../../utils', () => ({ + getNodesConnectedSourceOrTargetHandleIdsMap: vi.fn(), +})) + +const mockGetNodesConnectedSourceOrTargetHandleIdsMap = vi.mocked(getNodesConnectedSourceOrTargetHandleIdsMap) + +describe('use-edges-interactions.helpers', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('applyConnectedHandleNodeData should merge connected handle metadata into matching nodes', () => { + mockGetNodesConnectedSourceOrTargetHandleIdsMap.mockReturnValue({ + 'node-1': { + _connectedSourceHandleIds: ['branch-a'], + }, + }) + + const nodes = [ + createNode({ id: 'node-1', data: { title: 'Source' } }), + createNode({ id: 'node-2', data: { title: 'Target' } }), + ] + const edgeChanges = [{ + type: 'add', + edge: createEdge({ id: 'edge-1', source: 'node-1', target: 'node-2' }), + }] + + const result = applyConnectedHandleNodeData(nodes, edgeChanges) + + expect(result[0].data._connectedSourceHandleIds).toEqual(['branch-a']) + expect(result[1].data._connectedSourceHandleIds).toEqual([]) + expect(mockGetNodesConnectedSourceOrTargetHandleIdsMap).toHaveBeenCalledWith(edgeChanges, nodes) + }) + + it('clearEdgeMenuIfNeeded should return true only when the open menu belongs to a removed edge', () => { + expect(clearEdgeMenuIfNeeded({ + edgeMenu: { edgeId: 'edge-1' }, + edgeIds: ['edge-1', 'edge-2'], + })).toBe(true) + + expect(clearEdgeMenuIfNeeded({ + edgeMenu: { edgeId: 'edge-3' }, + edgeIds: ['edge-1', 'edge-2'], + })).toBe(false) + + expect(clearEdgeMenuIfNeeded({ + edgeIds: ['edge-1'], + })).toBe(false) + }) + + it('updateEdgeHoverState should toggle only the hovered edge flag', () => { + const edges = [ + createEdge({ id: 'edge-1', data: { _hovering: false } }), + createEdge({ id: 'edge-2', data: { _hovering: false } }), + ] + + const result = updateEdgeHoverState(edges, 'edge-2', true) + + expect(result.find(edge => edge.id === 'edge-1')?.data._hovering).toBe(false) + expect(result.find(edge => edge.id === 'edge-2')?.data._hovering).toBe(true) + }) + + it('updateEdgeSelectionState should update selected flags for select changes only', () => { + const edges = [ + createEdge({ id: 'edge-1', selected: false }), + createEdge({ id: 'edge-2', selected: true }), + ] + + const result = updateEdgeSelectionState(edges, [ + { type: 'select', id: 'edge-1', selected: true }, + { type: 'remove', id: 'edge-2' }, + ]) + + expect(result.find(edge => edge.id === 'edge-1')?.selected).toBe(true) + expect(result.find(edge => edge.id === 'edge-2')?.selected).toBe(true) + }) + + it('buildContextMenuEdges should select the target edge and clear bundled markers', () => { + const edges = [ + createEdge({ id: 'edge-1', selected: true, data: { _isBundled: true } }), + createEdge({ id: 'edge-2', selected: false, data: { _isBundled: true } }), + ] + + const result = buildContextMenuEdges(edges, 'edge-2') + + expect(result.find(edge => edge.id === 'edge-1')?.selected).toBe(false) + expect(result.find(edge => edge.id === 'edge-2')?.selected).toBe(true) + expect(result.every(edge => edge.data._isBundled === false)).toBe(true) + }) + + it('clearNodeSelectionState should clear selected state and bundled markers on every node', () => { + const nodes = [ + createNode({ id: 'node-1', selected: true, data: { selected: true, _isBundled: true } }), + createNode({ id: 'node-2', selected: false, data: { selected: true, _isBundled: true } }), + ] + + const result = clearNodeSelectionState(nodes) + + expect(result.every(node => node.selected === false)).toBe(true) + expect(result.every(node => node.data.selected === false)).toBe(true) + expect(result.every(node => node.data._isBundled === false)).toBe(true) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-fetch-workflow-inspect-vars.spec.ts b/web/app/components/workflow/hooks/__tests__/use-fetch-workflow-inspect-vars.spec.ts new file mode 100644 index 0000000000..e1e26732ae --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-fetch-workflow-inspect-vars.spec.ts @@ -0,0 +1,187 @@ +import type { SchemaTypeDefinition } from '@/service/use-common' +import type { VarInInspect } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { FlowType } from '@/types/common' +import { createNode } from '../../__tests__/fixtures' +import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum, VarType } from '../../types' +import { useSetWorkflowVarsWithValue } from '../use-fetch-workflow-inspect-vars' + +const mockFetchAllInspectVars = vi.hoisted(() => vi.fn()) +const mockInvalidateConversationVarValues = vi.hoisted(() => vi.fn()) +const mockInvalidateSysVarValues = vi.hoisted(() => vi.fn()) +const mockHandleCancelAllNodeSuccessStatus = vi.hoisted(() => vi.fn()) +const mockToNodeOutputVars = vi.hoisted(() => vi.fn()) + +const schemaTypeDefinitions: SchemaTypeDefinition[] = [{ + name: 'simple', + schema: { + properties: {}, + }, +}] + +vi.mock('reactflow', async () => + (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock()) + +vi.mock('@/service/use-tools', async () => + (await import('../../__tests__/service-mock-factory')).createToolServiceMock()) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidateConversationVarValues: () => mockInvalidateConversationVarValues, + useInvalidateSysVarValues: () => mockInvalidateSysVarValues, +})) + +vi.mock('@/service/workflow', () => ({ + fetchAllInspectVars: (...args: unknown[]) => mockFetchAllInspectVars(...args), +})) + +vi.mock('../use-nodes-interactions-without-sync', () => ({ + useNodesInteractionsWithoutSync: () => ({ + handleCancelAllNodeSuccessStatus: mockHandleCancelAllNodeSuccessStatus, + }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/use-match-schema-type', () => ({ + default: () => ({ + schemaTypeDefinitions, + }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/utils', () => ({ + toNodeOutputVars: (...args: unknown[]) => mockToNodeOutputVars(...args), +})) + +const createInspectVar = (overrides: Partial = {}): VarInInspect => ({ + id: 'var-1', + type: 'node', + name: 'answer', + description: 'Answer', + selector: ['node-1', 'answer'], + value_type: VarType.string, + value: 'hello', + edited: false, + visible: true, + is_truncated: false, + full_content: { + size_bytes: 5, + download_url: 'https://example.com/answer.txt', + }, + ...overrides, +}) + +describe('use-fetch-workflow-inspect-vars', () => { + beforeEach(() => { + vi.clearAllMocks() + resetReactFlowMockState() + rfState.nodes = [ + createNode({ + id: 'node-1', + data: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + }, + }), + ] + mockToNodeOutputVars.mockReturnValue([{ + nodeId: 'node-1', + vars: [{ + variable: 'answer', + schemaType: 'simple', + }], + }]) + }) + + it('fetches inspect vars, invalidates cached values, and stores schema-enriched node vars', async () => { + mockFetchAllInspectVars.mockResolvedValue([ + createInspectVar(), + createInspectVar({ + id: 'missing-node-var', + selector: ['missing-node', 'answer'], + }), + ]) + + const { result, store } = renderWorkflowHook( + () => useSetWorkflowVarsWithValue({ + flowType: FlowType.appFlow, + flowId: 'flow-1', + }), + { + initialStoreState: { + dataSourceList: [], + }, + }, + ) + + await act(async () => { + await result.current.fetchInspectVars({}) + }) + + expect(mockInvalidateConversationVarValues).toHaveBeenCalledTimes(1) + expect(mockInvalidateSysVarValues).toHaveBeenCalledTimes(1) + expect(mockFetchAllInspectVars).toHaveBeenCalledWith(FlowType.appFlow, 'flow-1') + expect(mockHandleCancelAllNodeSuccessStatus).toHaveBeenCalledTimes(1) + expect(store.getState().nodesWithInspectVars).toEqual([ + expect.objectContaining({ + nodeId: 'node-1', + nodeType: BlockEnum.Code, + title: 'Code', + vars: [ + expect.objectContaining({ + id: 'var-1', + selector: ['node-1', 'answer'], + schemaType: 'simple', + value: 'hello', + }), + ], + }), + ]) + }) + + it('accepts passed-in vars and plugin metadata without refetching from the API', async () => { + const passedInVars = [ + createInspectVar({ + id: 'var-2', + value: 'passed-in', + }), + ] + const passedInPluginInfo = { + buildInTools: [], + customTools: [], + workflowTools: [], + mcpTools: [], + dataSourceList: [], + } + + const { result, store } = renderWorkflowHook( + () => useSetWorkflowVarsWithValue({ + flowType: FlowType.appFlow, + flowId: 'flow-2', + }), + { + initialStoreState: { + dataSourceList: [], + }, + }, + ) + + await act(async () => { + await result.current.fetchInspectVars({ + passInVars: true, + vars: passedInVars, + passedInAllPluginInfoList: passedInPluginInfo, + passedInSchemaTypeDefinitions: schemaTypeDefinitions, + }) + }) + + await waitFor(() => { + expect(mockFetchAllInspectVars).not.toHaveBeenCalled() + expect(store.getState().nodesWithInspectVars[0]?.vars[0]).toMatchObject({ + id: 'var-2', + value: 'passed-in', + schemaType: 'simple', + }) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud-common.spec.ts b/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud-common.spec.ts new file mode 100644 index 0000000000..7b2006aa77 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud-common.spec.ts @@ -0,0 +1,210 @@ +import type { SchemaTypeDefinition } from '@/service/use-common' +import type { VarInInspect } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { FlowType } from '@/types/common' +import { createNode } from '../../__tests__/fixtures' +import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum, VarType } from '../../types' +import { useInspectVarsCrudCommon } from '../use-inspect-vars-crud-common' + +const mockFetchNodeInspectVars = vi.hoisted(() => vi.fn()) +const mockDoDeleteAllInspectorVars = vi.hoisted(() => vi.fn()) +const mockInvalidateConversationVarValues = vi.hoisted(() => vi.fn()) +const mockInvalidateSysVarValues = vi.hoisted(() => vi.fn()) +const mockHandleCancelNodeSuccessStatus = vi.hoisted(() => vi.fn()) +const mockHandleEdgeCancelRunningStatus = vi.hoisted(() => vi.fn()) +const mockToNodeOutputVars = vi.hoisted(() => vi.fn()) + +const schemaTypeDefinitions: SchemaTypeDefinition[] = [{ + name: 'simple', + schema: { + properties: {}, + }, +}] + +vi.mock('reactflow', async () => + (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock()) + +vi.mock('@/service/use-flow', () => ({ + default: () => ({ + useInvalidateConversationVarValues: () => mockInvalidateConversationVarValues, + useInvalidateSysVarValues: () => mockInvalidateSysVarValues, + useResetConversationVar: () => ({ mutateAsync: vi.fn() }), + useResetToLastRunValue: () => ({ mutateAsync: vi.fn() }), + useDeleteAllInspectorVars: () => ({ mutateAsync: mockDoDeleteAllInspectorVars }), + useDeleteNodeInspectorVars: () => ({ mutate: vi.fn() }), + useDeleteInspectVar: () => ({ mutate: vi.fn() }), + useEditInspectorVar: () => ({ mutateAsync: vi.fn() }), + }), +})) + +vi.mock('@/service/use-tools', async () => + (await import('../../__tests__/service-mock-factory')).createToolServiceMock()) + +vi.mock('@/service/workflow', () => ({ + fetchNodeInspectVars: (...args: unknown[]) => mockFetchNodeInspectVars(...args), +})) + +vi.mock('../use-nodes-interactions-without-sync', () => ({ + useNodesInteractionsWithoutSync: () => ({ + handleCancelNodeSuccessStatus: mockHandleCancelNodeSuccessStatus, + }), +})) + +vi.mock('../use-edges-interactions-without-sync', () => ({ + useEdgesInteractionsWithoutSync: () => ({ + handleEdgeCancelRunningStatus: mockHandleEdgeCancelRunningStatus, + }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/utils', async importOriginal => ({ + ...(await importOriginal()), + toNodeOutputVars: (...args: unknown[]) => mockToNodeOutputVars(...args), +})) + +const createInspectVar = (overrides: Partial = {}): VarInInspect => ({ + id: 'var-1', + type: 'node', + name: 'answer', + description: 'Answer', + selector: ['node-1', 'answer'], + value_type: VarType.string, + value: 'hello', + edited: false, + visible: true, + is_truncated: false, + full_content: { + size_bytes: 5, + download_url: 'https://example.com/answer.txt', + }, + ...overrides, +}) + +describe('useInspectVarsCrudCommon', () => { + beforeEach(() => { + vi.clearAllMocks() + resetReactFlowMockState() + rfState.nodes = [ + createNode({ + id: 'node-1', + data: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + }, + }), + ] + mockToNodeOutputVars.mockReturnValue([{ + nodeId: 'node-1', + vars: [{ + variable: 'answer', + schemaType: 'simple', + }], + }]) + }) + + it('invalidates cached system vars without refetching node values for system selectors', async () => { + const { result } = renderWorkflowHook( + () => useInspectVarsCrudCommon({ + flowId: 'flow-1', + flowType: FlowType.appFlow, + }), + { + initialStoreState: { + dataSourceList: [], + }, + }, + ) + + await act(async () => { + await result.current.fetchInspectVarValue(['sys', 'query'], schemaTypeDefinitions) + }) + + expect(mockInvalidateSysVarValues).toHaveBeenCalledTimes(1) + expect(mockFetchNodeInspectVars).not.toHaveBeenCalled() + }) + + it('fetches node inspect vars, adds schema types, and marks the node as fetched', async () => { + mockFetchNodeInspectVars.mockResolvedValue([ + createInspectVar(), + ]) + + const { result, store } = renderWorkflowHook( + () => useInspectVarsCrudCommon({ + flowId: 'flow-1', + flowType: FlowType.appFlow, + }), + { + initialStoreState: { + dataSourceList: [], + nodesWithInspectVars: [{ + nodeId: 'node-1', + nodePayload: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + } as never, + nodeType: BlockEnum.Code, + title: 'Code', + vars: [], + }], + }, + }, + ) + + await act(async () => { + await result.current.fetchInspectVarValue(['node-1', 'answer'], schemaTypeDefinitions) + }) + + await waitFor(() => { + expect(mockFetchNodeInspectVars).toHaveBeenCalledWith(FlowType.appFlow, 'flow-1', 'node-1') + expect(store.getState().nodesWithInspectVars[0]).toMatchObject({ + nodeId: 'node-1', + isValueFetched: true, + vars: [ + expect.objectContaining({ + id: 'var-1', + schemaType: 'simple', + }), + ], + }) + }) + }) + + it('deletes all inspect vars, invalidates cached values, and clears edge running state', async () => { + mockDoDeleteAllInspectorVars.mockResolvedValue(undefined) + + const { result, store } = renderWorkflowHook( + () => useInspectVarsCrudCommon({ + flowId: 'flow-1', + flowType: FlowType.appFlow, + }), + { + initialStoreState: { + nodesWithInspectVars: [{ + nodeId: 'node-1', + nodePayload: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + } as never, + nodeType: BlockEnum.Code, + title: 'Code', + vars: [createInspectVar()], + }], + }, + }, + ) + + await act(async () => { + await result.current.deleteAllInspectorVars() + }) + + expect(mockDoDeleteAllInspectorVars).toHaveBeenCalledTimes(1) + expect(mockInvalidateConversationVarValues).toHaveBeenCalledTimes(1) + expect(mockInvalidateSysVarValues).toHaveBeenCalledTimes(1) + expect(mockHandleEdgeCancelRunningStatus).toHaveBeenCalledTimes(1) + expect(store.getState().nodesWithInspectVars).toEqual([]) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud.spec.ts b/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud.spec.ts new file mode 100644 index 0000000000..193e4307de --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud.spec.ts @@ -0,0 +1,135 @@ +import type { VarInInspect } from '@/types/workflow' +import { FlowType } from '@/types/common' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum, VarType } from '../../types' +import useInspectVarsCrud from '../use-inspect-vars-crud' + +const mockUseConversationVarValues = vi.hoisted(() => vi.fn()) +const mockUseSysVarValues = vi.hoisted(() => vi.fn()) + +vi.mock('@/service/use-workflow', () => ({ + useConversationVarValues: (...args: unknown[]) => mockUseConversationVarValues(...args), + useSysVarValues: (...args: unknown[]) => mockUseSysVarValues(...args), +})) + +const createInspectVar = (overrides: Partial = {}): VarInInspect => ({ + id: 'var-1', + type: 'node', + name: 'answer', + description: 'Answer', + selector: ['node-1', 'answer'], + value_type: VarType.string, + value: 'hello', + edited: false, + visible: true, + is_truncated: false, + full_content: { + size_bytes: 5, + download_url: 'https://example.com/answer.txt', + }, + ...overrides, +}) + +describe('useInspectVarsCrud', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseConversationVarValues.mockReturnValue({ + data: [createInspectVar({ + id: 'conversation-var', + name: 'history', + selector: ['conversation', 'history'], + })], + }) + mockUseSysVarValues.mockReturnValue({ + data: [ + createInspectVar({ + id: 'query-var', + name: 'query', + selector: ['sys', 'query'], + }), + createInspectVar({ + id: 'files-var', + name: 'files', + selector: ['sys', 'files'], + }), + createInspectVar({ + id: 'time-var', + name: 'time', + selector: ['sys', 'time'], + }), + ], + }) + }) + + it('appends query/files system vars to start-node inspect vars and filters them from the system list', () => { + const hasNodeInspectVars = vi.fn(() => true) + const deleteAllInspectorVars = vi.fn() + const fetchInspectVarValue = vi.fn() + + const { result } = renderWorkflowHook(() => useInspectVarsCrud(), { + initialStoreState: { + nodesWithInspectVars: [{ + nodeId: 'start-node', + nodePayload: { + type: BlockEnum.Start, + title: 'Start', + desc: '', + } as never, + nodeType: BlockEnum.Start, + title: 'Start', + vars: [createInspectVar({ + id: 'start-answer', + selector: ['start-node', 'answer'], + })], + }], + }, + hooksStoreProps: { + configsMap: { + flowId: 'flow-1', + flowType: FlowType.appFlow, + fileSettings: {} as never, + }, + hasNodeInspectVars, + fetchInspectVarValue, + editInspectVarValue: vi.fn(), + renameInspectVarName: vi.fn(), + appendNodeInspectVars: vi.fn(), + deleteInspectVar: vi.fn(), + deleteNodeInspectorVars: vi.fn(), + deleteAllInspectorVars, + isInspectVarEdited: vi.fn(() => false), + resetToLastRunVar: vi.fn(), + invalidateSysVarValues: vi.fn(), + resetConversationVar: vi.fn(), + invalidateConversationVarValues: vi.fn(), + hasSetInspectVar: vi.fn(() => false), + }, + }) + + expect(result.current.conversationVars).toHaveLength(1) + expect(result.current.systemVars.map(item => item.name)).toEqual(['time']) + expect(result.current.nodesWithInspectVars[0]?.vars.map(item => item.name)).toEqual([ + 'answer', + 'query', + 'files', + ]) + expect(result.current.hasNodeInspectVars).toBe(hasNodeInspectVars) + expect(result.current.fetchInspectVarValue).toBe(fetchInspectVarValue) + expect(result.current.deleteAllInspectorVars).toBe(deleteAllInspectorVars) + }) + + it('uses an empty flow id for rag pipeline conversation and system value queries', () => { + renderWorkflowHook(() => useInspectVarsCrud(), { + hooksStoreProps: { + configsMap: { + flowId: 'rag-flow', + flowType: FlowType.ragPipeline, + fileSettings: {} as never, + }, + }, + }) + + expect(mockUseConversationVarValues).toHaveBeenCalledWith(FlowType.ragPipeline, '') + expect(mockUseSysVarValues).toHaveBeenCalledWith(FlowType.ragPipeline, '') + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-available-var-list.spec.ts b/web/app/components/workflow/hooks/__tests__/use-nodes-available-var-list.spec.ts new file mode 100644 index 0000000000..55db395f2e --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-available-var-list.spec.ts @@ -0,0 +1,110 @@ +import type { Node, NodeOutPutVar, Var } from '../../types' +import { renderHook } from '@testing-library/react' +import { BlockEnum, VarType } from '../../types' +import useNodesAvailableVarList, { useGetNodesAvailableVarList } from '../use-nodes-available-var-list' + +const mockGetTreeLeafNodes = vi.hoisted(() => vi.fn()) +const mockGetBeforeNodesInSameBranchIncludeParent = vi.hoisted(() => vi.fn()) +const mockGetNodeAvailableVars = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useIsChatMode: () => true, + useWorkflow: () => ({ + getTreeLeafNodes: mockGetTreeLeafNodes, + getBeforeNodesInSameBranchIncludeParent: mockGetBeforeNodesInSameBranchIncludeParent, + }), + useWorkflowVariables: () => ({ + getNodeAvailableVars: mockGetNodeAvailableVars, + }), +})) + +const createNode = (overrides: Partial = {}): Node => ({ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.LLM, + title: 'Node', + desc: '', + }, + ...overrides, +} as Node) + +const outputVars: NodeOutPutVar[] = [{ + nodeId: 'vars-node', + title: 'Vars', + vars: [{ + variable: 'name', + type: VarType.string, + }] satisfies Var[], +}] + +describe('useNodesAvailableVarList', () => { + beforeEach(() => { + vi.clearAllMocks() + mockGetBeforeNodesInSameBranchIncludeParent.mockImplementation((nodeId: string) => [createNode({ id: `before-${nodeId}` })]) + mockGetTreeLeafNodes.mockImplementation((nodeId: string) => [createNode({ id: `leaf-${nodeId}` })]) + mockGetNodeAvailableVars.mockReturnValue(outputVars) + }) + + it('builds availability per node, carrying loop nodes and parent iteration context', () => { + const loopNode = createNode({ + id: 'loop-1', + data: { + type: BlockEnum.Loop, + title: 'Loop', + desc: '', + }, + }) + const childNode = createNode({ + id: 'child-1', + parentId: 'loop-1', + data: { + type: BlockEnum.LLM, + title: 'Writer', + desc: '', + }, + }) + const filterVar = vi.fn(() => true) + + const { result } = renderHook(() => useNodesAvailableVarList([loopNode, childNode], { + filterVar, + hideEnv: true, + hideChatVar: true, + })) + + expect(mockGetBeforeNodesInSameBranchIncludeParent).toHaveBeenCalledWith('loop-1') + expect(mockGetBeforeNodesInSameBranchIncludeParent).toHaveBeenCalledWith('child-1') + expect(result.current['loop-1']?.availableNodes.map(node => node.id)).toEqual(['before-loop-1', 'loop-1']) + expect(result.current['child-1']?.availableVars).toBe(outputVars) + expect(mockGetNodeAvailableVars).toHaveBeenNthCalledWith(2, expect.objectContaining({ + parentNode: loopNode, + isChatMode: true, + filterVar, + hideEnv: true, + hideChatVar: true, + })) + }) + + it('returns a callback version that can use leaf nodes or caller-provided nodes', () => { + const firstNode = createNode({ id: 'node-a' }) + const secondNode = createNode({ id: 'node-b' }) + const filterVar = vi.fn(() => true) + const passedInAvailableNodes = [createNode({ id: 'manual-node' })] + + const { result } = renderHook(() => useGetNodesAvailableVarList()) + + const leafMap = result.current.getNodesAvailableVarList([firstNode], { + onlyLeafNodeVar: true, + filterVar, + }) + const manualMap = result.current.getNodesAvailableVarList([secondNode], { + filterVar, + passedInAvailableNodes, + }) + + expect(mockGetTreeLeafNodes).toHaveBeenCalledWith('node-a') + expect(leafMap['node-a']?.availableNodes.map(node => node.id)).toEqual(['leaf-node-a']) + expect(manualMap['node-b']?.availableNodes).toBe(passedInAvailableNodes) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-interactions-without-sync.spec.ts b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions-without-sync.spec.ts new file mode 100644 index 0000000000..1a2ebe9385 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions-without-sync.spec.ts @@ -0,0 +1,119 @@ +import { act, waitFor } from '@testing-library/react' +import { useNodes } from 'reactflow' +import { createNode } from '../../__tests__/fixtures' +import { renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../types' +import { useNodesInteractionsWithoutSync } from '../use-nodes-interactions-without-sync' + +type NodeRuntimeState = { + _runningStatus?: NodeRunningStatus + _waitingRun?: boolean +} + +const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => + (node?.data ?? {}) as NodeRuntimeState + +const createFlowNodes = () => [ + createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running, _waitingRun: true } }), + createNode({ id: 'n2', position: { x: 100, y: 0 }, data: { _runningStatus: NodeRunningStatus.Succeeded, _waitingRun: false } }), + createNode({ id: 'n3', position: { x: 200, y: 0 }, data: { _runningStatus: NodeRunningStatus.Failed, _waitingRun: true } }), +] + +const renderNodesInteractionsHook = () => + renderWorkflowFlowHook(() => ({ + ...useNodesInteractionsWithoutSync(), + nodes: useNodes(), + }), { + nodes: createFlowNodes(), + edges: [], + }) + +describe('useNodesInteractionsWithoutSync', () => { + it('clears _runningStatus and _waitingRun on all nodes', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleNodeCancelRunningStatus() + }) + + await waitFor(() => { + result.current.nodes.forEach((node) => { + const nodeState = getNodeRuntimeState(node) + expect(nodeState._runningStatus).toBeUndefined() + expect(nodeState._waitingRun).toBe(false) + }) + }) + }) + + it('clears _runningStatus only for Succeeded nodes', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelAllNodeSuccessStatus() + }) + + await waitFor(() => { + const n1 = result.current.nodes.find(node => node.id === 'n1') + const n2 = result.current.nodes.find(node => node.id === 'n2') + const n3 = result.current.nodes.find(node => node.id === 'n3') + + expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(n2)._runningStatus).toBeUndefined() + expect(getNodeRuntimeState(n3)._runningStatus).toBe(NodeRunningStatus.Failed) + }) + }) + + it('does not modify _waitingRun when clearing all success status', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelAllNodeSuccessStatus() + }) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n1'))._waitingRun).toBe(true) + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n3'))._waitingRun).toBe(true) + }) + }) + + it('clears _runningStatus and _waitingRun for the specified succeeded node', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelNodeSuccessStatus('n2') + }) + + await waitFor(() => { + const n2 = result.current.nodes.find(node => node.id === 'n2') + expect(getNodeRuntimeState(n2)._runningStatus).toBeUndefined() + expect(getNodeRuntimeState(n2)._waitingRun).toBe(false) + }) + }) + + it('does not modify nodes that are not succeeded', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelNodeSuccessStatus('n1') + }) + + await waitFor(() => { + const n1 = result.current.nodes.find(node => node.id === 'n1') + expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(n1)._waitingRun).toBe(true) + }) + }) + + it('does not modify other nodes', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelNodeSuccessStatus('n2') + }) + + await waitFor(() => { + const n1 = result.current.nodes.find(node => node.id === 'n1') + expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts new file mode 100644 index 0000000000..35a309902e --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts @@ -0,0 +1,205 @@ +import type { Edge, Node } from '../../types' +import { act } from '@testing-library/react' +import { createEdge, createNode } from '../../__tests__/fixtures' +import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import { useNodesInteractions } from '../use-nodes-interactions' + +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockSaveStateToHistory = vi.hoisted(() => vi.fn()) +const mockUndo = vi.hoisted(() => vi.fn()) +const mockRedo = vi.hoisted(() => vi.fn()) + +const runtimeState = vi.hoisted(() => ({ + nodesReadOnly: false, + workflowReadOnly: false, +})) + +let currentNodes: Node[] = [] +let currentEdges: Edge[] = [] + +vi.mock('reactflow', async () => + (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock()) + +vi.mock('../use-workflow', () => ({ + useWorkflow: () => ({ + getAfterNodesInSameBranch: () => [], + }), + useNodesReadOnly: () => ({ + getNodesReadOnly: () => runtimeState.nodesReadOnly, + }), + useWorkflowReadOnly: () => ({ + getWorkflowReadOnly: () => runtimeState.workflowReadOnly, + }), +})) + +vi.mock('../use-helpline', () => ({ + useHelpline: () => ({ + handleSetHelpline: () => ({ + showHorizontalHelpLineNodes: [], + showVerticalHelpLineNodes: [], + }), + }), +})) + +vi.mock('../use-nodes-meta-data', () => ({ + useNodesMetaData: () => ({ + nodesMap: {}, + }), +})) + +vi.mock('../use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft, + }), +})) + +vi.mock('../use-auto-generate-webhook-url', () => ({ + useAutoGenerateWebhookUrl: () => vi.fn(), +})) + +vi.mock('../use-inspect-vars-crud', () => ({ + default: () => ({ + deleteNodeInspectorVars: vi.fn(), + }), +})) + +vi.mock('../../nodes/iteration/use-interactions', () => ({ + useNodeIterationInteractions: () => ({ + handleNodeIterationChildDrag: () => ({ restrictPosition: {} }), + handleNodeIterationChildrenCopy: vi.fn(), + }), +})) + +vi.mock('../../nodes/loop/use-interactions', () => ({ + useNodeLoopInteractions: () => ({ + handleNodeLoopChildDrag: () => ({ restrictPosition: {} }), + handleNodeLoopChildrenCopy: vi.fn(), + }), +})) + +vi.mock('../use-workflow-history', async importOriginal => ({ + ...(await importOriginal()), + useWorkflowHistory: () => ({ + saveStateToHistory: mockSaveStateToHistory, + undo: mockUndo, + redo: mockRedo, + }), +})) + +describe('useNodesInteractions', () => { + beforeEach(() => { + vi.clearAllMocks() + resetReactFlowMockState() + runtimeState.nodesReadOnly = false + runtimeState.workflowReadOnly = false + currentNodes = [ + createNode({ + id: 'node-1', + position: { x: 10, y: 20 }, + data: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + }, + }), + ] + currentEdges = [ + createEdge({ + id: 'edge-1', + source: 'node-1', + target: 'node-2', + }), + ] + rfState.nodes = currentNodes as unknown as typeof rfState.nodes + rfState.edges = currentEdges as unknown as typeof rfState.edges + }) + + it('persists node drags only when the node position actually changes', () => { + const node = currentNodes[0] + const movedNode = { + ...node, + position: { x: 120, y: 80 }, + } + + const { result, store } = renderWorkflowHook(() => useNodesInteractions(), { + historyStore: { + nodes: currentNodes, + edges: currentEdges, + }, + }) + + act(() => { + result.current.handleNodeDragStart({} as never, node, currentNodes) + result.current.handleNodeDragStop({} as never, movedNode, currentNodes) + }) + + expect(store.getState().nodeAnimation).toBe(false) + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(1) + expect(mockSaveStateToHistory).toHaveBeenCalledWith('NodeDragStop', { + nodeId: 'node-1', + }) + }) + + it('restores history snapshots on undo and clears the edge menu', () => { + const historyNodes = [ + createNode({ + id: 'history-node', + data: { + type: BlockEnum.End, + title: 'End', + desc: '', + }, + }), + ] + const historyEdges = [ + createEdge({ + id: 'history-edge', + source: 'history-node', + target: 'node-1', + }), + ] + + const { result, store } = renderWorkflowHook(() => useNodesInteractions(), { + initialStoreState: { + edgeMenu: { + id: 'edge-1', + } as never, + }, + historyStore: { + nodes: historyNodes, + edges: historyEdges, + }, + }) + + act(() => { + result.current.handleHistoryBack() + }) + + expect(mockUndo).toHaveBeenCalledTimes(1) + expect(rfState.setNodes).toHaveBeenCalledWith(historyNodes) + expect(rfState.setEdges).toHaveBeenCalledWith(historyEdges) + expect(store.getState().edgeMenu).toBeUndefined() + }) + + it('skips undo and redo when the workflow is read-only', () => { + runtimeState.workflowReadOnly = true + const { result } = renderWorkflowHook(() => useNodesInteractions(), { + historyStore: { + nodes: currentNodes, + edges: currentEdges, + }, + }) + + act(() => { + result.current.handleHistoryBack() + result.current.handleHistoryForward() + }) + + expect(mockUndo).not.toHaveBeenCalled() + expect(mockRedo).not.toHaveBeenCalled() + expect(rfState.setNodes).not.toHaveBeenCalled() + expect(rfState.setEdges).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-meta-data.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-nodes-meta-data.spec.tsx new file mode 100644 index 0000000000..9dffa46cb2 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-meta-data.spec.tsx @@ -0,0 +1,153 @@ +import type { Node } from '../../types' +import { CollectionType } from '@/app/components/tools/types' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import { useNodeMetaData, useNodesMetaData } from '../use-nodes-meta-data' + +const buildInToolsState = vi.hoisted(() => [] as Array<{ id: string, author: string, description: Record }>) +const customToolsState = vi.hoisted(() => [] as Array<{ id: string, author: string, description: Record }>) +const workflowToolsState = vi.hoisted(() => [] as Array<{ id: string, author: string, description: Record }>) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: () => 'en-US', +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: buildInToolsState }), + useAllCustomTools: () => ({ data: customToolsState }), + useAllWorkflowTools: () => ({ data: workflowToolsState }), +})) + +const createNode = (overrides: Partial = {}): Node => ({ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.LLM, + title: 'Node', + desc: '', + }, + ...overrides, +} as Node) + +describe('useNodesMetaData', () => { + beforeEach(() => { + vi.clearAllMocks() + buildInToolsState.length = 0 + customToolsState.length = 0 + workflowToolsState.length = 0 + }) + + it('returns empty metadata collections when the hooks store has no node map', () => { + const { result } = renderWorkflowHook(() => useNodesMetaData(), { + hooksStoreProps: {}, + }) + + expect(result.current).toEqual({ + nodes: [], + nodesMap: {}, + }) + }) + + it('resolves built-in tool metadata from tool providers', () => { + buildInToolsState.push({ + id: 'provider-1', + author: 'Provider Author', + description: { + 'en-US': 'Built-in provider description', + }, + }) + + const toolNode = createNode({ + data: { + type: BlockEnum.Tool, + title: 'Tool Node', + desc: '', + provider_type: CollectionType.builtIn, + provider_id: 'provider-1', + }, + }) + + const { result } = renderWorkflowHook(() => useNodeMetaData(toolNode), { + hooksStoreProps: { + availableNodesMetaData: { + nodes: [], + }, + }, + }) + + expect(result.current).toEqual(expect.objectContaining({ + author: 'Provider Author', + description: 'Built-in provider description', + })) + }) + + it('prefers workflow store data for datasource nodes and keeps generic metadata for normal blocks', () => { + const datasourceNode = createNode({ + data: { + type: BlockEnum.DataSource, + title: 'Dataset', + desc: '', + plugin_id: 'datasource-1', + }, + }) + + const normalNode = createNode({ + data: { + type: BlockEnum.LLM, + title: 'Writer', + desc: '', + }, + }) + + const datasource = { + plugin_id: 'datasource-1', + author: 'Datasource Author', + description: { + 'en-US': 'Datasource description', + }, + } + + const metadataMap = { + [BlockEnum.LLM]: { + metaData: { + type: BlockEnum.LLM, + title: 'LLM', + author: 'Dify', + description: 'Node description', + }, + }, + } + + const datasourceResult = renderWorkflowHook(() => useNodeMetaData(datasourceNode), { + initialStoreState: { + dataSourceList: [datasource as never], + }, + hooksStoreProps: { + availableNodesMetaData: { + nodes: [], + nodesMap: metadataMap as never, + }, + }, + }) + + const normalResult = renderWorkflowHook(() => useNodeMetaData(normalNode), { + hooksStoreProps: { + availableNodesMetaData: { + nodes: [], + nodesMap: metadataMap as never, + }, + }, + }) + + expect(datasourceResult.result.current).toEqual(expect.objectContaining({ + author: 'Datasource Author', + description: 'Datasource description', + })) + expect(normalResult.result.current).toEqual(expect.objectContaining({ + author: 'Dify', + description: 'Node description', + title: 'LLM', + })) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-set-workflow-vars-with-value.spec.ts b/web/app/components/workflow/hooks/__tests__/use-set-workflow-vars-with-value.spec.ts new file mode 100644 index 0000000000..c0d693cf24 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-set-workflow-vars-with-value.spec.ts @@ -0,0 +1,14 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useSetWorkflowVarsWithValue } from '../use-set-workflow-vars-with-value' + +describe('useSetWorkflowVarsWithValue', () => { + it('returns fetchInspectVars from hooks store', () => { + const fetchInspectVars = vi.fn() + + const { result } = renderWorkflowHook(() => useSetWorkflowVarsWithValue(), { + hooksStoreProps: { fetchInspectVars }, + }) + + expect(result.current.fetchInspectVars).toBe(fetchInspectVars) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-shortcuts.spec.ts b/web/app/components/workflow/hooks/__tests__/use-shortcuts.spec.ts new file mode 100644 index 0000000000..b3c63ff519 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-shortcuts.spec.ts @@ -0,0 +1,168 @@ +import { act } from '@testing-library/react' +import { ZEN_TOGGLE_EVENT } from '@/app/components/goto-anything/actions/commands/zen' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useShortcuts } from '../use-shortcuts' + +type KeyPressRegistration = { + keyFilter: unknown + handler: (event: KeyboardEvent) => void + options?: { + events?: string[] + } +} + +const keyPressRegistrations = vi.hoisted(() => []) +const mockZoomTo = vi.hoisted(() => vi.fn()) +const mockGetZoom = vi.hoisted(() => vi.fn(() => 1)) +const mockFitView = vi.hoisted(() => vi.fn()) +const mockHandleNodesDelete = vi.hoisted(() => vi.fn()) +const mockHandleEdgeDelete = vi.hoisted(() => vi.fn()) +const mockHandleNodesCopy = vi.hoisted(() => vi.fn()) +const mockHandleNodesPaste = vi.hoisted(() => vi.fn()) +const mockHandleNodesDuplicate = vi.hoisted(() => vi.fn()) +const mockHandleHistoryBack = vi.hoisted(() => vi.fn()) +const mockHandleHistoryForward = vi.hoisted(() => vi.fn()) +const mockDimOtherNodes = vi.hoisted(() => vi.fn()) +const mockUndimAllNodes = vi.hoisted(() => vi.fn()) +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockHandleModeHand = vi.hoisted(() => vi.fn()) +const mockHandleModePointer = vi.hoisted(() => vi.fn()) +const mockHandleLayout = vi.hoisted(() => vi.fn()) +const mockHandleToggleMaximizeCanvas = vi.hoisted(() => vi.fn()) + +vi.mock('ahooks', () => ({ + useKeyPress: (keyFilter: unknown, handler: (event: KeyboardEvent) => void, options?: { events?: string[] }) => { + keyPressRegistrations.push({ keyFilter, handler, options }) + }, +})) + +vi.mock('reactflow', () => ({ + useReactFlow: () => ({ + zoomTo: mockZoomTo, + getZoom: mockGetZoom, + fitView: mockFitView, + }), +})) + +vi.mock('..', () => ({ + useNodesInteractions: () => ({ + handleNodesCopy: mockHandleNodesCopy, + handleNodesPaste: mockHandleNodesPaste, + handleNodesDuplicate: mockHandleNodesDuplicate, + handleNodesDelete: mockHandleNodesDelete, + handleHistoryBack: mockHandleHistoryBack, + handleHistoryForward: mockHandleHistoryForward, + dimOtherNodes: mockDimOtherNodes, + undimAllNodes: mockUndimAllNodes, + }), + useEdgesInteractions: () => ({ + handleEdgeDelete: mockHandleEdgeDelete, + }), + useNodesSyncDraft: () => ({ + handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft, + }), + useWorkflowCanvasMaximize: () => ({ + handleToggleMaximizeCanvas: mockHandleToggleMaximizeCanvas, + }), + useWorkflowMoveMode: () => ({ + handleModeHand: mockHandleModeHand, + handleModePointer: mockHandleModePointer, + }), + useWorkflowOrganize: () => ({ + handleLayout: mockHandleLayout, + }), +})) + +vi.mock('../../workflow-history-store', () => ({ + useWorkflowHistoryStore: () => ({ + shortcutsEnabled: true, + }), +})) + +const createKeyboardEvent = (target: HTMLElement = document.body) => ({ + preventDefault: vi.fn(), + target, +}) as unknown as KeyboardEvent + +const findRegistration = (matcher: (registration: KeyPressRegistration) => boolean) => { + const registration = keyPressRegistrations.find(matcher) + expect(registration).toBeDefined() + return registration as KeyPressRegistration +} + +describe('useShortcuts', () => { + beforeEach(() => { + keyPressRegistrations.length = 0 + vi.clearAllMocks() + }) + + it('deletes selected nodes and edges only outside editable inputs', () => { + renderWorkflowHook(() => useShortcuts()) + + const deleteShortcut = findRegistration(registration => + Array.isArray(registration.keyFilter) + && registration.keyFilter.includes('delete'), + ) + + const bodyEvent = createKeyboardEvent() + deleteShortcut.handler(bodyEvent) + + expect(bodyEvent.preventDefault).toHaveBeenCalled() + expect(mockHandleNodesDelete).toHaveBeenCalledTimes(1) + expect(mockHandleEdgeDelete).toHaveBeenCalledTimes(1) + + const inputEvent = createKeyboardEvent(document.createElement('input')) + deleteShortcut.handler(inputEvent) + + expect(mockHandleNodesDelete).toHaveBeenCalledTimes(1) + expect(mockHandleEdgeDelete).toHaveBeenCalledTimes(1) + }) + + it('runs layout and zoom shortcuts through the workflow actions', () => { + renderWorkflowHook(() => useShortcuts()) + + const layoutShortcut = findRegistration(registration => registration.keyFilter === 'ctrl.o' || registration.keyFilter === 'meta.o') + const fitViewShortcut = findRegistration(registration => registration.keyFilter === 'ctrl.1' || registration.keyFilter === 'meta.1') + const halfZoomShortcut = findRegistration(registration => registration.keyFilter === 'shift.5') + const zoomOutShortcut = findRegistration(registration => registration.keyFilter === 'ctrl.dash' || registration.keyFilter === 'meta.dash') + const zoomInShortcut = findRegistration(registration => registration.keyFilter === 'ctrl.equalsign' || registration.keyFilter === 'meta.equalsign') + + layoutShortcut.handler(createKeyboardEvent()) + fitViewShortcut.handler(createKeyboardEvent()) + halfZoomShortcut.handler(createKeyboardEvent()) + zoomOutShortcut.handler(createKeyboardEvent()) + zoomInShortcut.handler(createKeyboardEvent()) + + expect(mockHandleLayout).toHaveBeenCalledTimes(1) + expect(mockFitView).toHaveBeenCalledTimes(1) + expect(mockZoomTo).toHaveBeenNthCalledWith(1, 0.5) + expect(mockZoomTo).toHaveBeenNthCalledWith(2, 0.9) + expect(mockZoomTo).toHaveBeenNthCalledWith(3, 1.1) + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(4) + }) + + it('dims on shift down, undims on shift up, and responds to zen toggle events', () => { + const { unmount } = renderWorkflowHook(() => useShortcuts()) + + const shiftDownShortcut = findRegistration(registration => registration.keyFilter === 'shift' && registration.options?.events?.[0] === 'keydown') + const shiftUpShortcut = findRegistration(registration => typeof registration.keyFilter === 'function' && registration.options?.events?.[0] === 'keyup') + + shiftDownShortcut.handler(createKeyboardEvent()) + shiftUpShortcut.handler({ ...createKeyboardEvent(), key: 'Shift' } as KeyboardEvent) + + expect(mockDimOtherNodes).toHaveBeenCalledTimes(1) + expect(mockUndimAllNodes).toHaveBeenCalledTimes(1) + + act(() => { + window.dispatchEvent(new Event(ZEN_TOGGLE_EVENT)) + }) + expect(mockHandleToggleMaximizeCanvas).toHaveBeenCalledTimes(1) + + unmount() + + act(() => { + window.dispatchEvent(new Event(ZEN_TOGGLE_EVENT)) + }) + expect(mockHandleToggleMaximizeCanvas).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-without-sync-hooks.spec.ts b/web/app/components/workflow/hooks/__tests__/use-without-sync-hooks.spec.ts deleted file mode 100644 index 2d40028226..0000000000 --- a/web/app/components/workflow/hooks/__tests__/use-without-sync-hooks.spec.ts +++ /dev/null @@ -1,209 +0,0 @@ -import { act, waitFor } from '@testing-library/react' -import { useEdges, useNodes } from 'reactflow' -import { createEdge, createNode } from '../../__tests__/fixtures' -import { renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' -import { NodeRunningStatus } from '../../types' -import { useEdgesInteractionsWithoutSync } from '../use-edges-interactions-without-sync' -import { useNodesInteractionsWithoutSync } from '../use-nodes-interactions-without-sync' - -type EdgeRuntimeState = { - _sourceRunningStatus?: NodeRunningStatus - _targetRunningStatus?: NodeRunningStatus - _waitingRun?: boolean -} - -type NodeRuntimeState = { - _runningStatus?: NodeRunningStatus - _waitingRun?: boolean -} - -const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => - (edge?.data ?? {}) as EdgeRuntimeState - -const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => - (node?.data ?? {}) as NodeRuntimeState - -describe('useEdgesInteractionsWithoutSync', () => { - const createFlowNodes = () => [ - createNode({ id: 'a' }), - createNode({ id: 'b' }), - createNode({ id: 'c' }), - ] - const createFlowEdges = () => [ - createEdge({ - id: 'e1', - source: 'a', - target: 'b', - data: { - _sourceRunningStatus: NodeRunningStatus.Running, - _targetRunningStatus: NodeRunningStatus.Running, - _waitingRun: true, - }, - }), - createEdge({ - id: 'e2', - source: 'b', - target: 'c', - data: { - _sourceRunningStatus: NodeRunningStatus.Succeeded, - _targetRunningStatus: undefined, - _waitingRun: false, - }, - }), - ] - - const renderEdgesInteractionsHook = () => - renderWorkflowFlowHook(() => ({ - ...useEdgesInteractionsWithoutSync(), - edges: useEdges(), - }), { - nodes: createFlowNodes(), - edges: createFlowEdges(), - }) - - it('should clear running status and waitingRun on all edges', () => { - const { result } = renderEdgesInteractionsHook() - - act(() => { - result.current.handleEdgeCancelRunningStatus() - }) - - return waitFor(() => { - result.current.edges.forEach((edge) => { - const edgeState = getEdgeRuntimeState(edge) - expect(edgeState._sourceRunningStatus).toBeUndefined() - expect(edgeState._targetRunningStatus).toBeUndefined() - expect(edgeState._waitingRun).toBe(false) - }) - }) - }) - - it('should not mutate original edges', () => { - const edges = createFlowEdges() - const originalData = { ...getEdgeRuntimeState(edges[0]) } - const { result } = renderWorkflowFlowHook(() => ({ - ...useEdgesInteractionsWithoutSync(), - edges: useEdges(), - }), { - nodes: createFlowNodes(), - edges, - }) - - act(() => { - result.current.handleEdgeCancelRunningStatus() - }) - - expect(getEdgeRuntimeState(edges[0])._sourceRunningStatus).toBe(originalData._sourceRunningStatus) - }) -}) - -describe('useNodesInteractionsWithoutSync', () => { - const createFlowNodes = () => [ - createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running, _waitingRun: true } }), - createNode({ id: 'n2', position: { x: 100, y: 0 }, data: { _runningStatus: NodeRunningStatus.Succeeded, _waitingRun: false } }), - createNode({ id: 'n3', position: { x: 200, y: 0 }, data: { _runningStatus: NodeRunningStatus.Failed, _waitingRun: true } }), - ] - - const renderNodesInteractionsHook = () => - renderWorkflowFlowHook(() => ({ - ...useNodesInteractionsWithoutSync(), - nodes: useNodes(), - }), { - nodes: createFlowNodes(), - edges: [], - }) - - describe('handleNodeCancelRunningStatus', () => { - it('should clear _runningStatus and _waitingRun on all nodes', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleNodeCancelRunningStatus() - }) - - await waitFor(() => { - result.current.nodes.forEach((node) => { - const nodeState = getNodeRuntimeState(node) - expect(nodeState._runningStatus).toBeUndefined() - expect(nodeState._waitingRun).toBe(false) - }) - }) - }) - }) - - describe('handleCancelAllNodeSuccessStatus', () => { - it('should clear _runningStatus only for Succeeded nodes', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelAllNodeSuccessStatus() - }) - - await waitFor(() => { - const n1 = result.current.nodes.find(node => node.id === 'n1') - const n2 = result.current.nodes.find(node => node.id === 'n2') - const n3 = result.current.nodes.find(node => node.id === 'n3') - - expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(n2)._runningStatus).toBeUndefined() - expect(getNodeRuntimeState(n3)._runningStatus).toBe(NodeRunningStatus.Failed) - }) - }) - - it('should not modify _waitingRun', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelAllNodeSuccessStatus() - }) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n1'))._waitingRun).toBe(true) - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n3'))._waitingRun).toBe(true) - }) - }) - }) - - describe('handleCancelNodeSuccessStatus', () => { - it('should clear _runningStatus and _waitingRun for the specified Succeeded node', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelNodeSuccessStatus('n2') - }) - - await waitFor(() => { - const n2 = result.current.nodes.find(node => node.id === 'n2') - expect(getNodeRuntimeState(n2)._runningStatus).toBeUndefined() - expect(getNodeRuntimeState(n2)._waitingRun).toBe(false) - }) - }) - - it('should not modify nodes that are not Succeeded', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelNodeSuccessStatus('n1') - }) - - await waitFor(() => { - const n1 = result.current.nodes.find(node => node.id === 'n1') - expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(n1)._waitingRun).toBe(true) - }) - }) - - it('should not modify other nodes', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelNodeSuccessStatus('n2') - }) - - await waitFor(() => { - const n1 = result.current.nodes.find(node => node.id === 'n1') - expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) - }) - }) - }) -}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-canvas-maximize.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-canvas-maximize.spec.ts new file mode 100644 index 0000000000..f4cde1e72a --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-canvas-maximize.spec.ts @@ -0,0 +1,59 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { WorkflowRunningStatus } from '../../types' +import { useWorkflowCanvasMaximize } from '../use-workflow-canvas-maximize' + +const mockEmit = vi.hoisted(() => vi.fn()) + +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: mockEmit, + }, + }), +})) + +describe('useWorkflowCanvasMaximize', () => { + beforeEach(() => { + vi.clearAllMocks() + localStorage.clear() + }) + + it('toggles maximize state, persists it, and emits the canvas event', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowCanvasMaximize(), { + initialStoreState: { + maximizeCanvas: false, + }, + }) + + result.current.handleToggleMaximizeCanvas() + + expect(store.getState().maximizeCanvas).toBe(true) + expect(localStorage.getItem('workflow-canvas-maximize')).toBe('true') + expect(mockEmit).toHaveBeenCalledWith({ + type: 'workflow-canvas-maximize', + payload: true, + }) + }) + + it('does nothing while workflow nodes are read-only', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowCanvasMaximize(), { + initialStoreState: { + maximizeCanvas: false, + workflowRunningData: { + result: { + status: WorkflowRunningStatus.Running, + inputs_truncated: false, + process_data_truncated: false, + outputs_truncated: false, + }, + }, + }, + }) + + result.current.handleToggleMaximizeCanvas() + + expect(store.getState().maximizeCanvas).toBe(false) + expect(localStorage.getItem('workflow-canvas-maximize')).toBeNull() + expect(mockEmit).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-history.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-history.spec.tsx new file mode 100644 index 0000000000..54917d009c --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-history.spec.tsx @@ -0,0 +1,141 @@ +import type { Edge, Node } from '../../types' +import { act } from '@testing-library/react' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import { useWorkflowHistory, WorkflowHistoryEvent } from '../use-workflow-history' + +const reactFlowState = vi.hoisted(() => ({ + edges: [] as Edge[], + nodes: [] as Node[], +})) + +vi.mock('es-toolkit/compat', () => ({ + debounce: unknown>(fn: T) => fn, +})) + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => reactFlowState.nodes, + edges: reactFlowState.edges, + }), + }), + } +}) + +vi.mock('react-i18next', async () => { + const actual = await vi.importActual('react-i18next') + return { + ...actual, + useTranslation: () => ({ + t: (key: string) => key, + }), + } +}) + +const nodes: Node[] = [{ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: 'Start', + desc: '', + }, +}] + +const edges: Edge[] = [{ + id: 'edge-1', + source: 'node-1', + target: 'node-2', + type: 'custom', + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, +}] + +describe('useWorkflowHistory', () => { + beforeEach(() => { + reactFlowState.nodes = nodes + reactFlowState.edges = edges + }) + + it('stores the latest workflow graph snapshot for supported events', () => { + const { result } = renderWorkflowHook(() => useWorkflowHistory(), { + historyStore: { + nodes, + edges, + }, + }) + + act(() => { + result.current.saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: 'node-1' }) + }) + + expect(result.current.store.getState().workflowHistoryEvent).toBe(WorkflowHistoryEvent.NodeAdd) + expect(result.current.store.getState().workflowHistoryEventMeta).toEqual({ nodeId: 'node-1' }) + expect(result.current.store.getState().nodes).toEqual([ + expect.objectContaining({ + id: 'node-1', + data: expect.objectContaining({ + selected: false, + title: 'Start', + }), + }), + ]) + expect(result.current.store.getState().edges).toEqual([ + expect.objectContaining({ + id: 'edge-1', + selected: false, + source: 'node-1', + target: 'node-2', + }), + ]) + }) + + it('returns translated labels and falls back for unsupported events', () => { + const { result } = renderWorkflowHook(() => useWorkflowHistory(), { + historyStore: { + nodes, + edges, + }, + }) + + expect(result.current.getHistoryLabel(WorkflowHistoryEvent.NodeDelete)).toBe('changeHistory.nodeDelete') + expect(result.current.getHistoryLabel('Unknown' as keyof typeof WorkflowHistoryEvent)).toBe('Unknown Event') + }) + + it('runs registered undo and redo callbacks', () => { + const onUndo = vi.fn() + const onRedo = vi.fn() + + const { result } = renderWorkflowHook(() => useWorkflowHistory(), { + historyStore: { + nodes, + edges, + }, + }) + + act(() => { + result.current.onUndo(onUndo) + result.current.onRedo(onRedo) + }) + + const undoSpy = vi.spyOn(result.current.store.temporal.getState(), 'undo') + const redoSpy = vi.spyOn(result.current.store.temporal.getState(), 'redo') + + act(() => { + result.current.undo() + result.current.redo() + }) + + expect(undoSpy).toHaveBeenCalled() + expect(redoSpy).toHaveBeenCalled() + expect(onUndo).toHaveBeenCalled() + expect(onRedo).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-organize.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-organize.spec.tsx new file mode 100644 index 0000000000..424ad96630 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-organize.spec.tsx @@ -0,0 +1,152 @@ +import { act } from '@testing-library/react' +import { createLoopNode, createNode } from '../../__tests__/fixtures' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowOrganize } from '../use-workflow-organize' + +const mockSetViewport = vi.hoisted(() => vi.fn()) +const mockSetNodes = vi.hoisted(() => vi.fn()) +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockSaveStateToHistory = vi.hoisted(() => vi.fn()) +const mockGetLayoutForChildNodes = vi.hoisted(() => vi.fn()) +const mockGetLayoutByELK = vi.hoisted(() => vi.fn()) + +const runtimeState = vi.hoisted(() => ({ + nodes: [] as ReturnType[], + edges: [] as { id: string, source: string, target: string }[], + nodesReadOnly: false, +})) + +vi.mock('reactflow', () => ({ + Position: { + Left: 'left', + Right: 'right', + Top: 'top', + Bottom: 'bottom', + }, + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => runtimeState.nodes, + edges: runtimeState.edges, + setNodes: mockSetNodes, + }), + setState: vi.fn(), + }), + useReactFlow: () => ({ + setViewport: mockSetViewport, + }), +})) + +vi.mock('../use-workflow', () => ({ + useNodesReadOnly: () => ({ + getNodesReadOnly: () => runtimeState.nodesReadOnly, + nodesReadOnly: runtimeState.nodesReadOnly, + }), +})) + +vi.mock('../use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + handleSyncWorkflowDraft: (...args: unknown[]) => mockHandleSyncWorkflowDraft(...args), + }), +})) + +vi.mock('../use-workflow-history', () => ({ + useWorkflowHistory: () => ({ + saveStateToHistory: (...args: unknown[]) => mockSaveStateToHistory(...args), + }), + WorkflowHistoryEvent: { + LayoutOrganize: 'LayoutOrganize', + }, +})) + +vi.mock('../../utils/elk-layout', async importOriginal => ({ + ...(await importOriginal()), + getLayoutForChildNodes: (...args: unknown[]) => mockGetLayoutForChildNodes(...args), + getLayoutByELK: (...args: unknown[]) => mockGetLayoutByELK(...args), +})) + +describe('useWorkflowOrganize', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useFakeTimers() + runtimeState.nodesReadOnly = false + runtimeState.nodes = [] + runtimeState.edges = [] + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('resizes containers, lays out nodes, and syncs draft when editable', async () => { + runtimeState.nodes = [ + createLoopNode({ + id: 'loop-node', + width: 200, + height: 160, + }), + createNode({ + id: 'loop-child', + parentId: 'loop-node', + position: { x: 20, y: 20 }, + width: 100, + height: 60, + }), + createNode({ + id: 'top-node', + position: { x: 400, y: 0 }, + }), + ] + runtimeState.edges = [] + mockGetLayoutForChildNodes.mockResolvedValue({ + bounds: { minX: 0, minY: 0, maxX: 320, maxY: 220 }, + nodes: new Map([ + ['loop-child', { x: 40, y: 60, width: 100, height: 60 }], + ]), + }) + mockGetLayoutByELK.mockResolvedValue({ + nodes: new Map([ + ['loop-node', { x: 10, y: 20, width: 360, height: 260, layer: 0 }], + ['top-node', { x: 500, y: 30, width: 240, height: 100, layer: 0 }], + ]), + }) + + const { result } = renderWorkflowHook(() => useWorkflowOrganize()) + + await act(async () => { + await result.current.handleLayout() + }) + act(() => { + vi.runAllTimers() + }) + + expect(mockSetNodes).toHaveBeenCalledTimes(1) + const nextNodes = mockSetNodes.mock.calls[0][0] + expect(nextNodes.find((node: { id: string }) => node.id === 'loop-node')).toEqual(expect.objectContaining({ + width: expect.any(Number), + height: expect.any(Number), + position: { x: 10, y: 20 }, + })) + expect(nextNodes.find((node: { id: string }) => node.id === 'loop-child')).toEqual(expect.objectContaining({ + position: { x: 100, y: 120 }, + })) + expect(mockSetViewport).toHaveBeenCalledWith({ x: 0, y: 0, zoom: 0.7 }) + expect(mockSaveStateToHistory).toHaveBeenCalledWith('LayoutOrganize') + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('skips layout when nodes are read-only', async () => { + runtimeState.nodesReadOnly = true + runtimeState.nodes = [createNode({ id: 'n1' })] + + const { result } = renderWorkflowHook(() => useWorkflowOrganize()) + + await act(async () => { + await result.current.handleLayout() + }) + + expect(mockGetLayoutForChildNodes).not.toHaveBeenCalled() + expect(mockGetLayoutByELK).not.toHaveBeenCalled() + expect(mockSetNodes).not.toHaveBeenCalled() + expect(mockSetViewport).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-panel-interactions.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-panel-interactions.spec.tsx new file mode 100644 index 0000000000..9ff61f70f9 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-panel-interactions.spec.tsx @@ -0,0 +1,110 @@ +import { act } from '@testing-library/react' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { ControlMode } from '../../types' +import { + useWorkflowInteractions, + useWorkflowMoveMode, +} from '../use-workflow-panel-interactions' + +const mockHandleSelectionCancel = vi.hoisted(() => vi.fn()) +const mockHandleNodeCancelRunningStatus = vi.hoisted(() => vi.fn()) +const mockHandleEdgeCancelRunningStatus = vi.hoisted(() => vi.fn()) + +const runtimeState = vi.hoisted(() => ({ + nodesReadOnly: false, +})) + +vi.mock('../use-workflow', () => ({ + useNodesReadOnly: () => ({ + getNodesReadOnly: () => runtimeState.nodesReadOnly, + nodesReadOnly: runtimeState.nodesReadOnly, + }), +})) + +vi.mock('../use-selection-interactions', () => ({ + useSelectionInteractions: () => ({ + handleSelectionCancel: (...args: unknown[]) => mockHandleSelectionCancel(...args), + }), +})) + +vi.mock('../use-nodes-interactions-without-sync', () => ({ + useNodesInteractionsWithoutSync: () => ({ + handleNodeCancelRunningStatus: (...args: unknown[]) => mockHandleNodeCancelRunningStatus(...args), + }), +})) + +vi.mock('../use-edges-interactions-without-sync', () => ({ + useEdgesInteractionsWithoutSync: () => ({ + handleEdgeCancelRunningStatus: (...args: unknown[]) => mockHandleEdgeCancelRunningStatus(...args), + }), +})) + +describe('useWorkflowInteractions', () => { + beforeEach(() => { + vi.clearAllMocks() + runtimeState.nodesReadOnly = false + }) + + it('closes the debug panel and clears running state', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowInteractions(), { + initialStoreState: { + showDebugAndPreviewPanel: true, + workflowRunningData: { task_id: 'task-1' } as never, + }, + }) + + act(() => { + result.current.handleCancelDebugAndPreviewPanel() + }) + + expect(store.getState().showDebugAndPreviewPanel).toBe(false) + expect(store.getState().workflowRunningData).toBeUndefined() + expect(mockHandleNodeCancelRunningStatus).toHaveBeenCalledTimes(1) + expect(mockHandleEdgeCancelRunningStatus).toHaveBeenCalledTimes(1) + }) +}) + +describe('useWorkflowMoveMode', () => { + beforeEach(() => { + vi.clearAllMocks() + runtimeState.nodesReadOnly = false + }) + + it('switches between hand and pointer modes when editable', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowMoveMode(), { + initialStoreState: { + controlMode: ControlMode.Pointer, + }, + }) + + act(() => { + result.current.handleModeHand() + }) + + expect(store.getState().controlMode).toBe(ControlMode.Hand) + expect(mockHandleSelectionCancel).toHaveBeenCalledTimes(1) + + act(() => { + result.current.handleModePointer() + }) + + expect(store.getState().controlMode).toBe(ControlMode.Pointer) + }) + + it('does not switch modes when nodes are read-only', () => { + runtimeState.nodesReadOnly = true + const { result, store } = renderWorkflowHook(() => useWorkflowMoveMode(), { + initialStoreState: { + controlMode: ControlMode.Pointer, + }, + }) + + act(() => { + result.current.handleModeHand() + result.current.handleModePointer() + }) + + expect(store.getState().controlMode).toBe(ControlMode.Pointer) + expect(mockHandleSelectionCancel).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-refresh-draft.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-refresh-draft.spec.ts new file mode 100644 index 0000000000..83c8a4199b --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-refresh-draft.spec.ts @@ -0,0 +1,14 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowRefreshDraft } from '../use-workflow-refresh-draft' + +describe('useWorkflowRefreshDraft', () => { + it('returns handleRefreshWorkflowDraft from hooks store', () => { + const handleRefreshWorkflowDraft = vi.fn() + + const { result } = renderWorkflowHook(() => useWorkflowRefreshDraft(), { + hooksStoreProps: { handleRefreshWorkflowDraft }, + }) + + expect(result.current.handleRefreshWorkflowDraft).toBe(handleRefreshWorkflowDraft) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-store-only.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-store-only.spec.ts deleted file mode 100644 index 2085e5ab47..0000000000 --- a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-store-only.spec.ts +++ /dev/null @@ -1,242 +0,0 @@ -import type { - AgentLogResponse, - HumanInputFormFilledResponse, - HumanInputFormTimeoutResponse, - TextChunkResponse, - TextReplaceResponse, - WorkflowFinishedResponse, -} from '@/types/workflow' -import { baseRunningData, renderWorkflowHook } from '../../__tests__/workflow-test-env' -import { WorkflowRunningStatus } from '../../types' -import { useWorkflowAgentLog } from '../use-workflow-run-event/use-workflow-agent-log' -import { useWorkflowFailed } from '../use-workflow-run-event/use-workflow-failed' -import { useWorkflowFinished } from '../use-workflow-run-event/use-workflow-finished' -import { useWorkflowNodeHumanInputFormFilled } from '../use-workflow-run-event/use-workflow-node-human-input-form-filled' -import { useWorkflowNodeHumanInputFormTimeout } from '../use-workflow-run-event/use-workflow-node-human-input-form-timeout' -import { useWorkflowPaused } from '../use-workflow-run-event/use-workflow-paused' -import { useWorkflowTextChunk } from '../use-workflow-run-event/use-workflow-text-chunk' -import { useWorkflowTextReplace } from '../use-workflow-run-event/use-workflow-text-replace' - -vi.mock('@/app/components/base/file-uploader/utils', () => ({ - getFilesInLogs: vi.fn(() => []), -})) - -describe('useWorkflowFailed', () => { - it('should set status to Failed', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowFailed(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - result.current.handleWorkflowFailed() - - expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Failed) - }) -}) - -describe('useWorkflowPaused', () => { - it('should set status to Paused', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowPaused(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - result.current.handleWorkflowPaused() - - expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Paused) - }) -}) - -describe('useWorkflowTextChunk', () => { - it('should append text and activate result tab', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowTextChunk(), { - initialStoreState: { - workflowRunningData: baseRunningData({ resultText: 'Hello' }), - }, - }) - - result.current.handleWorkflowTextChunk({ data: { text: ' World' } } as TextChunkResponse) - - const state = store.getState().workflowRunningData! - expect(state.resultText).toBe('Hello World') - expect(state.resultTabActive).toBe(true) - }) -}) - -describe('useWorkflowTextReplace', () => { - it('should replace resultText', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowTextReplace(), { - initialStoreState: { - workflowRunningData: baseRunningData({ resultText: 'old text' }), - }, - }) - - result.current.handleWorkflowTextReplace({ data: { text: 'new text' } } as TextReplaceResponse) - - expect(store.getState().workflowRunningData!.resultText).toBe('new text') - }) -}) - -describe('useWorkflowFinished', () => { - it('should merge data into result and activate result tab for single string output', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowFinished(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - result.current.handleWorkflowFinished({ - data: { status: 'succeeded', outputs: { answer: 'hello' } }, - } as WorkflowFinishedResponse) - - const state = store.getState().workflowRunningData! - expect(state.result.status).toBe('succeeded') - expect(state.resultTabActive).toBe(true) - expect(state.resultText).toBe('hello') - }) - - it('should not activate result tab for multi-key outputs', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowFinished(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - result.current.handleWorkflowFinished({ - data: { status: 'succeeded', outputs: { a: 'hello', b: 'world' } }, - } as WorkflowFinishedResponse) - - expect(store.getState().workflowRunningData!.resultTabActive).toBeFalsy() - }) -}) - -describe('useWorkflowAgentLog', () => { - it('should create agent_log array when execution_metadata has no agent_log', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n1', execution_metadata: {} }], - }), - }, - }) - - result.current.handleWorkflowAgentLog({ - data: { node_id: 'n1', message_id: 'm1' }, - } as AgentLogResponse) - - const trace = store.getState().workflowRunningData!.tracing![0] - expect(trace.execution_metadata!.agent_log).toHaveLength(1) - expect(trace.execution_metadata!.agent_log![0].message_id).toBe('m1') - }) - - it('should append to existing agent_log', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ - node_id: 'n1', - execution_metadata: { agent_log: [{ message_id: 'm1', text: 'log1' }] }, - }], - }), - }, - }) - - result.current.handleWorkflowAgentLog({ - data: { node_id: 'n1', message_id: 'm2' }, - } as AgentLogResponse) - - expect(store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log).toHaveLength(2) - }) - - it('should update existing log entry by message_id', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ - node_id: 'n1', - execution_metadata: { agent_log: [{ message_id: 'm1', text: 'old' }] }, - }], - }), - }, - }) - - result.current.handleWorkflowAgentLog({ - data: { node_id: 'n1', message_id: 'm1', text: 'new' }, - } as unknown as AgentLogResponse) - - const log = store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log! - expect(log).toHaveLength(1) - expect((log[0] as unknown as { text: string }).text).toBe('new') - }) - - it('should create execution_metadata when it does not exist', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n1' }], - }), - }, - }) - - result.current.handleWorkflowAgentLog({ - data: { node_id: 'n1', message_id: 'm1' }, - } as AgentLogResponse) - - expect(store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log).toHaveLength(1) - }) -}) - -describe('useWorkflowNodeHumanInputFormFilled', () => { - it('should remove form from humanInputFormDataList and add to humanInputFilledFormDataList', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormFilled(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, - ], - }), - }, - }) - - result.current.handleWorkflowNodeHumanInputFormFilled({ - data: { node_id: 'n1', node_title: 'Node 1', rendered_content: 'done' }, - } as HumanInputFormFilledResponse) - - const state = store.getState().workflowRunningData! - expect(state.humanInputFormDataList).toHaveLength(0) - expect(state.humanInputFilledFormDataList).toHaveLength(1) - expect(state.humanInputFilledFormDataList![0].node_id).toBe('n1') - }) - - it('should create humanInputFilledFormDataList when it does not exist', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormFilled(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, - ], - }), - }, - }) - - result.current.handleWorkflowNodeHumanInputFormFilled({ - data: { node_id: 'n1', node_title: 'Node 1', rendered_content: 'done' }, - } as HumanInputFormFilledResponse) - - expect(store.getState().workflowRunningData!.humanInputFilledFormDataList).toBeDefined() - }) -}) - -describe('useWorkflowNodeHumanInputFormTimeout', () => { - it('should set expiration_time on the matching form', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormTimeout(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '', expiration_time: 0 }, - ], - }), - }, - }) - - result.current.handleWorkflowNodeHumanInputFormTimeout({ - data: { node_id: 'n1', node_title: 'Node 1', expiration_time: 1000 }, - } as HumanInputFormTimeoutResponse) - - expect(store.getState().workflowRunningData!.humanInputFormDataList![0].expiration_time).toBe(1000) - }) -}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-store.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-store.spec.ts deleted file mode 100644 index 1c8a0764d1..0000000000 --- a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-store.spec.ts +++ /dev/null @@ -1,336 +0,0 @@ -import type { WorkflowRunningData } from '../../types' -import type { - IterationFinishedResponse, - IterationNextResponse, - LoopFinishedResponse, - LoopNextResponse, - NodeFinishedResponse, - WorkflowStartedResponse, -} from '@/types/workflow' -import { act, waitFor } from '@testing-library/react' -import { useEdges, useNodes } from 'reactflow' -import { createEdge, createNode } from '../../__tests__/fixtures' -import { baseRunningData, renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' -import { DEFAULT_ITER_TIMES } from '../../constants' -import { NodeRunningStatus, WorkflowRunningStatus } from '../../types' -import { useWorkflowNodeFinished } from '../use-workflow-run-event/use-workflow-node-finished' -import { useWorkflowNodeIterationFinished } from '../use-workflow-run-event/use-workflow-node-iteration-finished' -import { useWorkflowNodeIterationNext } from '../use-workflow-run-event/use-workflow-node-iteration-next' -import { useWorkflowNodeLoopFinished } from '../use-workflow-run-event/use-workflow-node-loop-finished' -import { useWorkflowNodeLoopNext } from '../use-workflow-run-event/use-workflow-node-loop-next' -import { useWorkflowNodeRetry } from '../use-workflow-run-event/use-workflow-node-retry' -import { useWorkflowStarted } from '../use-workflow-run-event/use-workflow-started' - -type NodeRuntimeState = { - _waitingRun?: boolean - _runningStatus?: NodeRunningStatus - _retryIndex?: number - _iterationIndex?: number - _loopIndex?: number - _runningBranchId?: string -} - -type EdgeRuntimeState = { - _sourceRunningStatus?: NodeRunningStatus - _targetRunningStatus?: NodeRunningStatus - _waitingRun?: boolean -} - -const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => - (node?.data ?? {}) as NodeRuntimeState - -const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => - (edge?.data ?? {}) as EdgeRuntimeState - -function createRunNodes() { - return [ - createNode({ - id: 'n1', - width: 200, - height: 80, - data: { _waitingRun: false }, - }), - ] -} - -function createRunEdges() { - return [ - createEdge({ - id: 'e1', - source: 'n0', - target: 'n1', - data: {}, - }), - ] -} - -function renderRunEventHook>( - useHook: () => T, - options?: { - nodes?: ReturnType - edges?: ReturnType - initialStoreState?: Record - }, -) { - const { nodes = createRunNodes(), edges = createRunEdges(), initialStoreState } = options ?? {} - - return renderWorkflowFlowHook(() => ({ - ...useHook(), - nodes: useNodes(), - edges: useEdges(), - }), { - nodes, - edges, - reactFlowProps: { fitView: false }, - initialStoreState, - }) -} - -describe('useWorkflowStarted', () => { - it('should initialize workflow running data and reset nodes/edges', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowStarted(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowStarted({ - task_id: 'task-2', - data: { id: 'run-1', workflow_id: 'wf-1', created_at: 1000 }, - } as WorkflowStartedResponse) - }) - - const state = store.getState().workflowRunningData! - expect(state.task_id).toBe('task-2') - expect(state.result.status).toBe(WorkflowRunningStatus.Running) - expect(state.resultText).toBe('') - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._waitingRun).toBe(true) - expect(getNodeRuntimeState(result.current.nodes[0])._runningBranchId).toBeUndefined() - expect(getEdgeRuntimeState(result.current.edges[0])._sourceRunningStatus).toBeUndefined() - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBeUndefined() - expect(getEdgeRuntimeState(result.current.edges[0])._waitingRun).toBe(true) - }) - }) - - it('should resume from Paused without resetting nodes/edges', () => { - const { result, store } = renderRunEventHook(() => useWorkflowStarted(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - result: { status: WorkflowRunningStatus.Paused } as WorkflowRunningData['result'], - }), - }, - }) - - act(() => { - result.current.handleWorkflowStarted({ - task_id: 'task-2', - data: { id: 'run-2', workflow_id: 'wf-1', created_at: 2000 }, - } as WorkflowStartedResponse) - }) - - expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Running) - expect(getNodeRuntimeState(result.current.nodes[0])._waitingRun).toBe(false) - expect(getEdgeRuntimeState(result.current.edges[0])._waitingRun).toBeUndefined() - }) -}) - -describe('useWorkflowNodeFinished', () => { - it('should update tracing and node running status', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeFinished(), { - nodes: [ - createNode({ - id: 'n1', - data: { _runningStatus: NodeRunningStatus.Running }, - }), - ], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Running }], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeFinished({ - data: { id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, - } as NodeFinishedResponse) - }) - - const trace = store.getState().workflowRunningData!.tracing![0] - expect(trace.status).toBe(NodeRunningStatus.Succeeded) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) - }) - }) - - it('should set _runningBranchId for IfElse node', async () => { - const { result } = renderRunEventHook(() => useWorkflowNodeFinished(), { - nodes: [ - createNode({ - id: 'n1', - data: { _runningStatus: NodeRunningStatus.Running }, - }), - ], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Running }], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeFinished({ - data: { - id: 'trace-1', - node_id: 'n1', - node_type: 'if-else', - status: NodeRunningStatus.Succeeded, - outputs: { selected_case_id: 'branch-a' }, - }, - } as unknown as NodeFinishedResponse) - }) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._runningBranchId).toBe('branch-a') - }) - }) -}) - -describe('useWorkflowNodeRetry', () => { - it('should push retry data to tracing and update _retryIndex', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeRetry(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeRetry({ - data: { node_id: 'n1', retry_index: 2 }, - } as NodeFinishedResponse) - }) - - expect(store.getState().workflowRunningData!.tracing).toHaveLength(1) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._retryIndex).toBe(2) - }) - }) -}) - -describe('useWorkflowNodeIterationNext', () => { - it('should set _iterationIndex and increment iterTimes', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeIterationNext(), { - initialStoreState: { - workflowRunningData: baseRunningData(), - iterTimes: 3, - }, - }) - - act(() => { - result.current.handleWorkflowNodeIterationNext({ - data: { node_id: 'n1' }, - } as IterationNextResponse) - }) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._iterationIndex).toBe(3) - }) - expect(store.getState().iterTimes).toBe(4) - }) -}) - -describe('useWorkflowNodeIterationFinished', () => { - it('should update tracing, reset iterTimes, update node status and edges', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeIterationFinished(), { - nodes: [ - createNode({ - id: 'n1', - data: { _runningStatus: NodeRunningStatus.Running }, - }), - ], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ id: 'iter-1', node_id: 'n1', status: NodeRunningStatus.Running }], - }), - iterTimes: 10, - }, - }) - - act(() => { - result.current.handleWorkflowNodeIterationFinished({ - data: { id: 'iter-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, - } as IterationFinishedResponse) - }) - - expect(store.getState().iterTimes).toBe(DEFAULT_ITER_TIMES) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) - }) - }) -}) - -describe('useWorkflowNodeLoopNext', () => { - it('should set _loopIndex and reset child nodes to waiting', async () => { - const { result } = renderRunEventHook(() => useWorkflowNodeLoopNext(), { - nodes: [ - createNode({ id: 'n1', data: {} }), - createNode({ - id: 'n2', - position: { x: 300, y: 0 }, - parentId: 'n1', - data: { _waitingRun: false }, - }), - ], - edges: [], - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeLoopNext({ - data: { node_id: 'n1', index: 5 }, - } as LoopNextResponse) - }) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n1'))._loopIndex).toBe(5) - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n2'))._waitingRun).toBe(true) - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n2'))._runningStatus).toBe(NodeRunningStatus.Waiting) - }) - }) -}) - -describe('useWorkflowNodeLoopFinished', () => { - it('should update tracing, node status and edges', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeLoopFinished(), { - nodes: [ - createNode({ - id: 'n1', - data: { _runningStatus: NodeRunningStatus.Running }, - }), - ], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ id: 'loop-1', node_id: 'n1', status: NodeRunningStatus.Running }], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeLoopFinished({ - data: { id: 'loop-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, - } as LoopFinishedResponse) - }) - - const trace = store.getState().workflowRunningData!.tracing![0] - expect(trace.status).toBe(NodeRunningStatus.Succeeded) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) - }) - }) -}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-viewport.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-viewport.spec.ts deleted file mode 100644 index 73b16acf2e..0000000000 --- a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-viewport.spec.ts +++ /dev/null @@ -1,331 +0,0 @@ -import type { - HumanInputRequiredResponse, - IterationStartedResponse, - LoopStartedResponse, - NodeStartedResponse, -} from '@/types/workflow' -import { act, waitFor } from '@testing-library/react' -import { useEdges, useNodes, useStoreApi } from 'reactflow' -import { createEdge, createNode } from '../../__tests__/fixtures' -import { baseRunningData, renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' -import { DEFAULT_ITER_TIMES } from '../../constants' -import { NodeRunningStatus } from '../../types' -import { useWorkflowNodeHumanInputRequired } from '../use-workflow-run-event/use-workflow-node-human-input-required' -import { useWorkflowNodeIterationStarted } from '../use-workflow-run-event/use-workflow-node-iteration-started' -import { useWorkflowNodeLoopStarted } from '../use-workflow-run-event/use-workflow-node-loop-started' -import { useWorkflowNodeStarted } from '../use-workflow-run-event/use-workflow-node-started' - -type NodeRuntimeState = { - _waitingRun?: boolean - _runningStatus?: NodeRunningStatus - _iterationLength?: number - _loopLength?: number -} - -type EdgeRuntimeState = { - _sourceRunningStatus?: NodeRunningStatus - _targetRunningStatus?: NodeRunningStatus - _waitingRun?: boolean -} - -const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => - (node?.data ?? {}) as NodeRuntimeState - -const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => - (edge?.data ?? {}) as EdgeRuntimeState - -const containerParams = { clientWidth: 1200, clientHeight: 800 } - -function createViewportNodes() { - return [ - createNode({ - id: 'n0', - width: 200, - height: 80, - data: { _runningStatus: NodeRunningStatus.Succeeded }, - }), - createNode({ - id: 'n1', - position: { x: 100, y: 50 }, - width: 200, - height: 80, - data: { _waitingRun: true }, - }), - createNode({ - id: 'n2', - position: { x: 400, y: 50 }, - width: 200, - height: 80, - parentId: 'n1', - data: { _waitingRun: true }, - }), - ] -} - -function createViewportEdges() { - return [ - createEdge({ - id: 'e1', - source: 'n0', - target: 'n1', - sourceHandle: 'source', - data: {}, - }), - ] -} - -function renderViewportHook>( - useHook: () => T, - options?: { - nodes?: ReturnType - edges?: ReturnType - initialStoreState?: Record - }, -) { - const { - nodes = createViewportNodes(), - edges = createViewportEdges(), - initialStoreState, - } = options ?? {} - - return renderWorkflowFlowHook(() => ({ - ...useHook(), - nodes: useNodes(), - edges: useEdges(), - reactFlowStore: useStoreApi(), - }), { - nodes, - edges, - reactFlowProps: { fitView: false }, - initialStoreState, - }) -} - -describe('useWorkflowNodeStarted', () => { - it('should push to tracing, set node running, and adjust viewport for root node', async () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeStarted(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeStarted( - { data: { node_id: 'n1' } } as NodeStartedResponse, - containerParams, - ) - }) - - const tracing = store.getState().workflowRunningData!.tracing! - expect(tracing).toHaveLength(1) - expect(tracing[0].status).toBe(NodeRunningStatus.Running) - - await waitFor(() => { - const transform = result.current.reactFlowStore.getState().transform - expect(transform[0]).toBe(200) - expect(transform[1]).toBe(310) - expect(transform[2]).toBe(1) - - const node = result.current.nodes.find(item => item.id === 'n1') - expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(node)._waitingRun).toBe(false) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) - }) - }) - - it('should not adjust viewport for child node (has parentId)', async () => { - const { result } = renderViewportHook(() => useWorkflowNodeStarted(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeStarted( - { data: { node_id: 'n2' } } as NodeStartedResponse, - containerParams, - ) - }) - - await waitFor(() => { - const transform = result.current.reactFlowStore.getState().transform - expect(transform[0]).toBe(0) - expect(transform[1]).toBe(0) - expect(transform[2]).toBe(1) - expect(getNodeRuntimeState(result.current.nodes.find(item => item.id === 'n2'))._runningStatus).toBe(NodeRunningStatus.Running) - }) - }) - - it('should update existing tracing entry if node_id exists at non-zero index', () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeStarted(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [ - { node_id: 'n0', status: NodeRunningStatus.Succeeded }, - { node_id: 'n1', status: NodeRunningStatus.Succeeded }, - ], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeStarted( - { data: { node_id: 'n1' } } as NodeStartedResponse, - containerParams, - ) - }) - - const tracing = store.getState().workflowRunningData!.tracing! - expect(tracing).toHaveLength(2) - expect(tracing[1].status).toBe(NodeRunningStatus.Running) - }) -}) - -describe('useWorkflowNodeIterationStarted', () => { - it('should push to tracing, reset iterTimes, set viewport, and update node with _iterationLength', async () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeIterationStarted(), { - nodes: createViewportNodes().slice(0, 2), - initialStoreState: { - workflowRunningData: baseRunningData(), - iterTimes: 99, - }, - }) - - act(() => { - result.current.handleWorkflowNodeIterationStarted( - { data: { node_id: 'n1', metadata: { iterator_length: 10 } } } as IterationStartedResponse, - containerParams, - ) - }) - - const tracing = store.getState().workflowRunningData!.tracing! - expect(tracing[0].status).toBe(NodeRunningStatus.Running) - expect(store.getState().iterTimes).toBe(DEFAULT_ITER_TIMES) - - await waitFor(() => { - const transform = result.current.reactFlowStore.getState().transform - expect(transform[0]).toBe(200) - expect(transform[1]).toBe(310) - expect(transform[2]).toBe(1) - - const node = result.current.nodes.find(item => item.id === 'n1') - expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(node)._iterationLength).toBe(10) - expect(getNodeRuntimeState(node)._waitingRun).toBe(false) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) - }) - }) -}) - -describe('useWorkflowNodeLoopStarted', () => { - it('should push to tracing, set viewport, and update node with _loopLength', async () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeLoopStarted(), { - nodes: createViewportNodes().slice(0, 2), - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeLoopStarted( - { data: { node_id: 'n1', metadata: { loop_length: 5 } } } as LoopStartedResponse, - containerParams, - ) - }) - - expect(store.getState().workflowRunningData!.tracing![0].status).toBe(NodeRunningStatus.Running) - - await waitFor(() => { - const transform = result.current.reactFlowStore.getState().transform - expect(transform[0]).toBe(200) - expect(transform[1]).toBe(310) - expect(transform[2]).toBe(1) - - const node = result.current.nodes.find(item => item.id === 'n1') - expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(node)._loopLength).toBe(5) - expect(getNodeRuntimeState(node)._waitingRun).toBe(false) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) - }) - }) -}) - -describe('useWorkflowNodeHumanInputRequired', () => { - it('should create humanInputFormDataList and set tracing/node to Paused', async () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { - nodes: [ - createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), - createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), - ], - edges: [], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n1', status: NodeRunningStatus.Running }], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeHumanInputRequired({ - data: { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: 'content' }, - } as HumanInputRequiredResponse) - }) - - const state = store.getState().workflowRunningData! - expect(state.humanInputFormDataList).toHaveLength(1) - expect(state.humanInputFormDataList![0].form_id).toBe('f1') - expect(state.tracing![0].status).toBe(NodeRunningStatus.Paused) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes.find(item => item.id === 'n1'))._runningStatus).toBe(NodeRunningStatus.Paused) - }) - }) - - it('should update existing form entry for same node_id', () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { - nodes: [ - createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), - createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), - ], - edges: [], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n1', status: NodeRunningStatus.Running }], - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'old', node_title: 'Node 1', form_content: 'old' }, - ], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeHumanInputRequired({ - data: { node_id: 'n1', form_id: 'new', node_title: 'Node 1', form_content: 'new' }, - } as HumanInputRequiredResponse) - }) - - const formList = store.getState().workflowRunningData!.humanInputFormDataList! - expect(formList).toHaveLength(1) - expect(formList[0].form_id).toBe('new') - }) - - it('should append new form entry for different node_id', () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { - nodes: [ - createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), - createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), - ], - edges: [], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n2', status: NodeRunningStatus.Running }], - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, - ], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeHumanInputRequired({ - data: { node_id: 'n2', form_id: 'f2', node_title: 'Node 2', form_content: 'content2' }, - } as HumanInputRequiredResponse) - }) - - expect(store.getState().workflowRunningData!.humanInputFormDataList).toHaveLength(2) - }) -}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-run.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-run.spec.ts new file mode 100644 index 0000000000..ff8c64656e --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-run.spec.ts @@ -0,0 +1,24 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowRun } from '../use-workflow-run' + +describe('useWorkflowRun', () => { + it('returns workflow run handlers from hooks store', () => { + const handlers = { + handleBackupDraft: vi.fn(), + handleLoadBackupDraft: vi.fn(), + handleRestoreFromPublishedWorkflow: vi.fn(), + handleRun: vi.fn(), + handleStopRun: vi.fn(), + } + + const { result } = renderWorkflowHook(() => useWorkflowRun(), { + hooksStoreProps: handlers, + }) + + expect(result.current.handleBackupDraft).toBe(handlers.handleBackupDraft) + expect(result.current.handleLoadBackupDraft).toBe(handlers.handleLoadBackupDraft) + expect(result.current.handleRestoreFromPublishedWorkflow).toBe(handlers.handleRestoreFromPublishedWorkflow) + expect(result.current.handleRun).toBe(handlers.handleRun) + expect(result.current.handleStopRun).toBe(handlers.handleStopRun) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-search.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-search.spec.tsx new file mode 100644 index 0000000000..4e9f4c9b45 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-search.spec.tsx @@ -0,0 +1,119 @@ +import type { CommonNodeType, Node, ToolWithProvider } from '../../types' +import { act, renderHook } from '@testing-library/react' +import { workflowNodesAction } from '@/app/components/goto-anything/actions/workflow-nodes' +import { CollectionType } from '@/app/components/tools/types' +import { BlockEnum } from '../../types' +import { useWorkflowSearch } from '../use-workflow-search' + +const mockHandleNodeSelect = vi.hoisted(() => vi.fn()) +const runtimeNodes = vi.hoisted(() => [] as Node[]) + +vi.mock('reactflow', () => ({ + useNodes: () => runtimeNodes, +})) + +vi.mock('../use-nodes-interactions', () => ({ + useNodesInteractions: () => ({ + handleNodeSelect: mockHandleNodeSelect, + }), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ + data: [{ + id: 'provider-1', + icon: 'tool-icon', + tools: [], + }] satisfies Partial[], + }), + useAllCustomTools: () => ({ data: [] }), + useAllWorkflowTools: () => ({ data: [] }), + useAllMCPTools: () => ({ data: [] }), +})) + +const createNode = (overrides: Partial = {}): Node => ({ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.LLM, + title: 'Writer', + desc: 'Draft content', + } as CommonNodeType, + ...overrides, +}) + +describe('useWorkflowSearch', () => { + beforeEach(() => { + vi.clearAllMocks() + runtimeNodes.length = 0 + workflowNodesAction.searchFn = undefined + }) + + it('registers workflow node search results with tool icons and llm metadata scoring', async () => { + runtimeNodes.push( + createNode({ + id: 'llm-1', + data: { + type: BlockEnum.LLM, + title: 'Writer', + desc: 'Draft content', + model: { + provider: 'openai', + name: 'gpt-4o', + mode: 'chat', + }, + } as CommonNodeType, + }), + createNode({ + id: 'tool-1', + data: { + type: BlockEnum.Tool, + title: 'Google Search', + desc: 'Search the web', + provider_type: CollectionType.builtIn, + provider_id: 'provider-1', + } as CommonNodeType, + }), + createNode({ + id: 'internal-start', + data: { + type: BlockEnum.IterationStart, + title: 'Internal Start', + desc: '', + } as CommonNodeType, + }), + ) + + const { unmount } = renderHook(() => useWorkflowSearch()) + + const llmResults = await workflowNodesAction.search('', 'gpt') + expect(llmResults.map(item => item.id)).toEqual(['llm-1']) + expect(llmResults[0]?.title).toBe('Writer') + + const toolResults = await workflowNodesAction.search('', 'search') + expect(toolResults.map(item => item.id)).toEqual(['tool-1']) + expect(toolResults[0]?.description).toBe('Search the web') + + unmount() + + expect(workflowNodesAction.searchFn).toBeUndefined() + }) + + it('binds the node selection listener to handleNodeSelect', () => { + const { unmount } = renderHook(() => useWorkflowSearch()) + + act(() => { + document.dispatchEvent(new CustomEvent('workflow:select-node', { + detail: { + nodeId: 'node-42', + focus: false, + }, + })) + }) + + expect(mockHandleNodeSelect).toHaveBeenCalledWith('node-42') + + unmount() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-start-run.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-start-run.spec.tsx new file mode 100644 index 0000000000..fdde912285 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-start-run.spec.tsx @@ -0,0 +1,28 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowStartRun } from '../use-workflow-start-run' + +describe('useWorkflowStartRun', () => { + it('returns start-run handlers from hooks store', () => { + const handlers = { + handleStartWorkflowRun: vi.fn(), + handleWorkflowStartRunInWorkflow: vi.fn(), + handleWorkflowStartRunInChatflow: vi.fn(), + handleWorkflowTriggerScheduleRunInWorkflow: vi.fn(), + handleWorkflowTriggerWebhookRunInWorkflow: vi.fn(), + handleWorkflowTriggerPluginRunInWorkflow: vi.fn(), + handleWorkflowRunAllTriggersInWorkflow: vi.fn(), + } + + const { result } = renderWorkflowHook(() => useWorkflowStartRun(), { + hooksStoreProps: handlers, + }) + + expect(result.current.handleStartWorkflowRun).toBe(handlers.handleStartWorkflowRun) + expect(result.current.handleWorkflowStartRunInWorkflow).toBe(handlers.handleWorkflowStartRunInWorkflow) + expect(result.current.handleWorkflowStartRunInChatflow).toBe(handlers.handleWorkflowStartRunInChatflow) + expect(result.current.handleWorkflowTriggerScheduleRunInWorkflow).toBe(handlers.handleWorkflowTriggerScheduleRunInWorkflow) + expect(result.current.handleWorkflowTriggerWebhookRunInWorkflow).toBe(handlers.handleWorkflowTriggerWebhookRunInWorkflow) + expect(result.current.handleWorkflowTriggerPluginRunInWorkflow).toBe(handlers.handleWorkflowTriggerPluginRunInWorkflow) + expect(result.current.handleWorkflowRunAllTriggersInWorkflow).toBe(handlers.handleWorkflowRunAllTriggersInWorkflow) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-update.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-update.spec.tsx new file mode 100644 index 0000000000..8bd2a1c4f3 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-update.spec.tsx @@ -0,0 +1,66 @@ +import { act } from '@testing-library/react' +import { createNode } from '../../__tests__/fixtures' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowUpdate } from '../use-workflow-update' + +const mockSetViewport = vi.hoisted(() => vi.fn()) +const mockEventEmit = vi.hoisted(() => vi.fn()) +const mockInitialNodes = vi.hoisted(() => vi.fn((nodes: unknown[], _edges: unknown[]) => nodes)) +const mockInitialEdges = vi.hoisted(() => vi.fn((edges: unknown[], _nodes: unknown[]) => edges)) + +vi.mock('reactflow', () => ({ + Position: { + Left: 'left', + Right: 'right', + Top: 'top', + Bottom: 'bottom', + }, + useReactFlow: () => ({ + setViewport: mockSetViewport, + }), +})) + +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: (...args: unknown[]) => mockEventEmit(...args), + }, + }), +})) + +vi.mock('../../utils', async importOriginal => ({ + ...(await importOriginal()), + initialNodes: (nodes: unknown[], edges: unknown[]) => mockInitialNodes(nodes, edges), + initialEdges: (edges: unknown[], nodes: unknown[]) => mockInitialEdges(edges, nodes), +})) + +describe('useWorkflowUpdate', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('emits initialized data and only sets a valid viewport', () => { + const { result } = renderWorkflowHook(() => useWorkflowUpdate()) + + act(() => { + result.current.handleUpdateWorkflowCanvas({ + nodes: [createNode({ id: 'n1' })], + edges: [], + viewport: { x: 10, y: 20, zoom: 0.5 }, + } as never) + result.current.handleUpdateWorkflowCanvas({ + nodes: [], + edges: [], + viewport: { x: 'bad' } as never, + }) + }) + + expect(mockInitialNodes).toHaveBeenCalled() + expect(mockInitialEdges).toHaveBeenCalled() + expect(mockEventEmit).toHaveBeenCalledWith(expect.objectContaining({ + type: 'WORKFLOW_DATA_UPDATE', + })) + expect(mockSetViewport).toHaveBeenCalledTimes(1) + expect(mockSetViewport).toHaveBeenCalledWith({ x: 10, y: 20, zoom: 0.5 }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-zoom.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-zoom.spec.ts new file mode 100644 index 0000000000..83bc1b27ad --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-zoom.spec.ts @@ -0,0 +1,86 @@ +import { act, renderHook } from '@testing-library/react' +import { useWorkflowZoom } from '../use-workflow-zoom' + +const { + mockFitView, + mockZoomIn, + mockZoomOut, + mockZoomTo, + mockHandleSyncWorkflowDraft, + runtimeState, +} = vi.hoisted(() => ({ + mockFitView: vi.fn(), + mockZoomIn: vi.fn(), + mockZoomOut: vi.fn(), + mockZoomTo: vi.fn(), + mockHandleSyncWorkflowDraft: vi.fn(), + runtimeState: { + workflowReadOnly: false, + }, +})) + +vi.mock('reactflow', () => ({ + useReactFlow: () => ({ + fitView: mockFitView, + zoomIn: mockZoomIn, + zoomOut: mockZoomOut, + zoomTo: mockZoomTo, + }), +})) + +vi.mock('../use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + handleSyncWorkflowDraft: (...args: unknown[]) => mockHandleSyncWorkflowDraft(...args), + }), +})) + +vi.mock('../use-workflow', () => ({ + useWorkflowReadOnly: () => ({ + getWorkflowReadOnly: () => runtimeState.workflowReadOnly, + }), +})) + +describe('useWorkflowZoom', () => { + beforeEach(() => { + vi.clearAllMocks() + runtimeState.workflowReadOnly = false + }) + + it('runs zoom actions and syncs the workflow draft when editable', () => { + const { result } = renderHook(() => useWorkflowZoom()) + + act(() => { + result.current.handleFitView() + result.current.handleBackToOriginalSize() + result.current.handleSizeToHalf() + result.current.handleZoomOut() + result.current.handleZoomIn() + }) + + expect(mockFitView).toHaveBeenCalledTimes(1) + expect(mockZoomTo).toHaveBeenCalledWith(1) + expect(mockZoomTo).toHaveBeenCalledWith(0.5) + expect(mockZoomOut).toHaveBeenCalledTimes(1) + expect(mockZoomIn).toHaveBeenCalledTimes(1) + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(5) + }) + + it('blocks zoom actions when the workflow is read-only', () => { + runtimeState.workflowReadOnly = true + const { result } = renderHook(() => useWorkflowZoom()) + + act(() => { + result.current.handleFitView() + result.current.handleBackToOriginalSize() + result.current.handleSizeToHalf() + result.current.handleZoomOut() + result.current.handleZoomIn() + }) + + expect(mockFitView).not.toHaveBeenCalled() + expect(mockZoomTo).not.toHaveBeenCalled() + expect(mockZoomOut).not.toHaveBeenCalled() + expect(mockZoomIn).not.toHaveBeenCalled() + expect(mockHandleSyncWorkflowDraft).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/test-helpers.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/test-helpers.ts new file mode 100644 index 0000000000..8c2ed18f19 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/test-helpers.ts @@ -0,0 +1,186 @@ +import type { WorkflowRunningData } from '../../../types' +import type { + IterationFinishedResponse, + IterationNextResponse, + LoopFinishedResponse, + LoopNextResponse, + NodeFinishedResponse, + NodeStartedResponse, + WorkflowStartedResponse, +} from '@/types/workflow' +import { useEdges, useNodes, useStoreApi } from 'reactflow' +import { createEdge, createNode } from '../../../__tests__/fixtures' +import { renderWorkflowFlowHook } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus, WorkflowRunningStatus } from '../../../types' + +type NodeRuntimeState = { + _waitingRun?: boolean + _runningStatus?: NodeRunningStatus + _retryIndex?: number + _iterationIndex?: number + _iterationLength?: number + _loopIndex?: number + _loopLength?: number + _runningBranchId?: string +} + +type EdgeRuntimeState = { + _sourceRunningStatus?: NodeRunningStatus + _targetRunningStatus?: NodeRunningStatus + _waitingRun?: boolean +} + +export const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => + (node?.data ?? {}) as NodeRuntimeState + +export const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => + (edge?.data ?? {}) as EdgeRuntimeState + +function createRunNodes() { + return [ + createNode({ + id: 'n1', + width: 200, + height: 80, + data: { _waitingRun: false }, + }), + ] +} + +function createRunEdges() { + return [ + createEdge({ + id: 'e1', + source: 'n0', + target: 'n1', + data: {}, + }), + ] +} + +export function createViewportNodes() { + return [ + createNode({ + id: 'n0', + width: 200, + height: 80, + data: { _runningStatus: NodeRunningStatus.Succeeded }, + }), + createNode({ + id: 'n1', + position: { x: 100, y: 50 }, + width: 200, + height: 80, + data: { _waitingRun: true }, + }), + createNode({ + id: 'n2', + position: { x: 400, y: 50 }, + width: 200, + height: 80, + parentId: 'n1', + data: { _waitingRun: true }, + }), + ] +} + +function createViewportEdges() { + return [ + createEdge({ + id: 'e1', + source: 'n0', + target: 'n1', + sourceHandle: 'source', + data: {}, + }), + ] +} + +export const containerParams = { clientWidth: 1200, clientHeight: 800 } + +export function renderRunEventHook>( + useHook: () => T, + options?: { + nodes?: ReturnType + edges?: ReturnType + initialStoreState?: Record + }, +) { + const { nodes = createRunNodes(), edges = createRunEdges(), initialStoreState } = options ?? {} + + return renderWorkflowFlowHook(() => ({ + ...useHook(), + nodes: useNodes(), + edges: useEdges(), + }), { + nodes, + edges, + reactFlowProps: { fitView: false }, + initialStoreState, + }) +} + +export function renderViewportHook>( + useHook: () => T, + options?: { + nodes?: ReturnType + edges?: ReturnType + initialStoreState?: Record + }, +) { + const { + nodes = createViewportNodes(), + edges = createViewportEdges(), + initialStoreState, + } = options ?? {} + + return renderWorkflowFlowHook(() => ({ + ...useHook(), + nodes: useNodes(), + edges: useEdges(), + reactFlowStore: useStoreApi(), + }), { + nodes, + edges, + reactFlowProps: { fitView: false }, + initialStoreState, + }) +} + +export const createStartedResponse = (overrides: Partial = {}): WorkflowStartedResponse => ({ + task_id: 'task-2', + data: { id: 'run-1', workflow_id: 'wf-1', created_at: 1000 }, + ...overrides, +} as WorkflowStartedResponse) + +export const createNodeFinishedResponse = (overrides: Partial = {}): NodeFinishedResponse => ({ + data: { id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, + ...overrides, +} as NodeFinishedResponse) + +export const createIterationNextResponse = (overrides: Partial = {}): IterationNextResponse => ({ + data: { node_id: 'n1' }, + ...overrides, +} as IterationNextResponse) + +export const createIterationFinishedResponse = (overrides: Partial = {}): IterationFinishedResponse => ({ + data: { id: 'iter-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, + ...overrides, +} as IterationFinishedResponse) + +export const createLoopNextResponse = (overrides: Partial = {}): LoopNextResponse => ({ + data: { node_id: 'n1', index: 5 }, + ...overrides, +} as LoopNextResponse) + +export const createLoopFinishedResponse = (overrides: Partial = {}): LoopFinishedResponse => ({ + data: { id: 'loop-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, + ...overrides, +} as LoopFinishedResponse) + +export const createNodeStartedResponse = (overrides: Partial = {}): NodeStartedResponse => ({ + data: { node_id: 'n1' }, + ...overrides, +} as NodeStartedResponse) + +export const pausedRunningData = (): WorkflowRunningData['result'] => ({ status: WorkflowRunningStatus.Paused } as WorkflowRunningData['result']) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-agent-log.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-agent-log.spec.ts new file mode 100644 index 0000000000..cabfc0f6d1 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-agent-log.spec.ts @@ -0,0 +1,83 @@ +import type { AgentLogResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowAgentLog } from '../use-workflow-agent-log' + +vi.mock('@/app/components/base/file-uploader/utils', () => ({ + getFilesInLogs: vi.fn(() => []), +})) + +describe('useWorkflowAgentLog', () => { + it('creates agent_log when execution_metadata has none', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n1', execution_metadata: {} }], + }), + }, + }) + + result.current.handleWorkflowAgentLog({ + data: { node_id: 'n1', message_id: 'm1' }, + } as AgentLogResponse) + + const trace = store.getState().workflowRunningData!.tracing![0] + expect(trace.execution_metadata!.agent_log).toHaveLength(1) + expect(trace.execution_metadata!.agent_log![0].message_id).toBe('m1') + }) + + it('appends to existing agent_log', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ + node_id: 'n1', + execution_metadata: { agent_log: [{ message_id: 'm1', text: 'log1' }] }, + }], + }), + }, + }) + + result.current.handleWorkflowAgentLog({ + data: { node_id: 'n1', message_id: 'm2' }, + } as AgentLogResponse) + + expect(store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log).toHaveLength(2) + }) + + it('updates an existing log entry by message_id', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ + node_id: 'n1', + execution_metadata: { agent_log: [{ message_id: 'm1', text: 'old' }] }, + }], + }), + }, + }) + + result.current.handleWorkflowAgentLog({ + data: { node_id: 'n1', message_id: 'm1', text: 'new' }, + } as unknown as AgentLogResponse) + + const log = store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log! + expect(log).toHaveLength(1) + expect((log[0] as unknown as { text: string }).text).toBe('new') + }) + + it('creates execution_metadata when it does not exist', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n1' }], + }), + }, + }) + + result.current.handleWorkflowAgentLog({ + data: { node_id: 'n1', message_id: 'm1' }, + } as AgentLogResponse) + + expect(store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log).toHaveLength(1) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-failed.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-failed.spec.ts new file mode 100644 index 0000000000..53ee281f7e --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-failed.spec.ts @@ -0,0 +1,15 @@ +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { WorkflowRunningStatus } from '../../../types' +import { useWorkflowFailed } from '../use-workflow-failed' + +describe('useWorkflowFailed', () => { + it('sets status to Failed', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowFailed(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + result.current.handleWorkflowFailed() + + expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Failed) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-finished.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-finished.spec.ts new file mode 100644 index 0000000000..910b64ed18 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-finished.spec.ts @@ -0,0 +1,32 @@ +import type { WorkflowFinishedResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowFinished } from '../use-workflow-finished' + +describe('useWorkflowFinished', () => { + it('merges data into result and activates result tab for single string output', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowFinished(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + result.current.handleWorkflowFinished({ + data: { status: 'succeeded', outputs: { answer: 'hello' } }, + } as WorkflowFinishedResponse) + + const state = store.getState().workflowRunningData! + expect(state.result.status).toBe('succeeded') + expect(state.resultTabActive).toBe(true) + expect(state.resultText).toBe('hello') + }) + + it('does not activate the result tab for multi-key outputs', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowFinished(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + result.current.handleWorkflowFinished({ + data: { status: 'succeeded', outputs: { a: 'hello', b: 'world' } }, + } as WorkflowFinishedResponse) + + expect(store.getState().workflowRunningData!.resultTabActive).toBeFalsy() + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-finished.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-finished.spec.ts new file mode 100644 index 0000000000..efcdc15d88 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-finished.spec.ts @@ -0,0 +1,73 @@ +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { BlockEnum, NodeRunningStatus } from '../../../types' +import { useWorkflowNodeFinished } from '../use-workflow-node-finished' +import { + createNodeFinishedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeFinished', () => { + it('updates tracing and node running status', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeFinished(), { + nodes: [ + createNode({ + id: 'n1', + data: { _runningStatus: NodeRunningStatus.Running }, + }), + ], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Running }], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeFinished(createNodeFinishedResponse()) + }) + + const trace = store.getState().workflowRunningData!.tracing![0] + expect(trace.status).toBe(NodeRunningStatus.Succeeded) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) + }) + }) + + it('sets _runningBranchId for IfElse nodes', async () => { + const { result } = renderRunEventHook(() => useWorkflowNodeFinished(), { + nodes: [ + createNode({ + id: 'n1', + data: { _runningStatus: NodeRunningStatus.Running }, + }), + ], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Running }], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeFinished(createNodeFinishedResponse({ + data: { + id: 'trace-1', + node_id: 'n1', + node_type: BlockEnum.IfElse, + status: NodeRunningStatus.Succeeded, + outputs: { selected_case_id: 'branch-a' }, + } as never, + })) + }) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._runningBranchId).toBe('branch-a') + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-filled.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-filled.spec.ts new file mode 100644 index 0000000000..aa8e89327b --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-filled.spec.ts @@ -0,0 +1,44 @@ +import type { HumanInputFormFilledResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowNodeHumanInputFormFilled } from '../use-workflow-node-human-input-form-filled' + +describe('useWorkflowNodeHumanInputFormFilled', () => { + it('removes the form from pending and adds it to filled', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormFilled(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, + ], + }), + }, + }) + + result.current.handleWorkflowNodeHumanInputFormFilled({ + data: { node_id: 'n1', node_title: 'Node 1', rendered_content: 'done' }, + } as HumanInputFormFilledResponse) + + const state = store.getState().workflowRunningData! + expect(state.humanInputFormDataList).toHaveLength(0) + expect(state.humanInputFilledFormDataList).toHaveLength(1) + expect(state.humanInputFilledFormDataList![0].node_id).toBe('n1') + }) + + it('creates humanInputFilledFormDataList when it does not exist', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormFilled(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, + ], + }), + }, + }) + + result.current.handleWorkflowNodeHumanInputFormFilled({ + data: { node_id: 'n1', node_title: 'Node 1', rendered_content: 'done' }, + } as HumanInputFormFilledResponse) + + expect(store.getState().workflowRunningData!.humanInputFilledFormDataList).toBeDefined() + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-timeout.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-timeout.spec.ts new file mode 100644 index 0000000000..e528b49846 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-timeout.spec.ts @@ -0,0 +1,23 @@ +import type { HumanInputFormTimeoutResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowNodeHumanInputFormTimeout } from '../use-workflow-node-human-input-form-timeout' + +describe('useWorkflowNodeHumanInputFormTimeout', () => { + it('sets expiration_time on the matching form', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormTimeout(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '', expiration_time: 0 }, + ], + }), + }, + }) + + result.current.handleWorkflowNodeHumanInputFormTimeout({ + data: { node_id: 'n1', node_title: 'Node 1', expiration_time: 1000 }, + } as HumanInputFormTimeoutResponse) + + expect(store.getState().workflowRunningData!.humanInputFormDataList![0].expiration_time).toBe(1000) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-required.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-required.spec.ts new file mode 100644 index 0000000000..23fdf8a3c3 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-required.spec.ts @@ -0,0 +1,96 @@ +import type { HumanInputRequiredResponse } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeHumanInputRequired } from '../use-workflow-node-human-input-required' +import { + getNodeRuntimeState, + renderViewportHook, +} from './test-helpers' + +describe('useWorkflowNodeHumanInputRequired', () => { + it('creates humanInputFormDataList and sets tracing and node to Paused', async () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { + nodes: [ + createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), + createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), + ], + edges: [], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n1', status: NodeRunningStatus.Running }], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeHumanInputRequired({ + data: { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: 'content' }, + } as HumanInputRequiredResponse) + }) + + const state = store.getState().workflowRunningData! + expect(state.humanInputFormDataList).toHaveLength(1) + expect(state.humanInputFormDataList![0].form_id).toBe('f1') + expect(state.tracing![0].status).toBe(NodeRunningStatus.Paused) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes.find(item => item.id === 'n1'))._runningStatus).toBe(NodeRunningStatus.Paused) + }) + }) + + it('updates existing form entry for the same node_id', () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { + nodes: [ + createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), + createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), + ], + edges: [], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n1', status: NodeRunningStatus.Running }], + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'old', node_title: 'Node 1', form_content: 'old' }, + ], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeHumanInputRequired({ + data: { node_id: 'n1', form_id: 'new', node_title: 'Node 1', form_content: 'new' }, + } as HumanInputRequiredResponse) + }) + + const formList = store.getState().workflowRunningData!.humanInputFormDataList! + expect(formList).toHaveLength(1) + expect(formList[0].form_id).toBe('new') + }) + + it('appends a new form entry for a different node_id', () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { + nodes: [ + createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), + createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), + ], + edges: [], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n2', status: NodeRunningStatus.Running }], + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, + ], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeHumanInputRequired({ + data: { node_id: 'n2', form_id: 'f2', node_title: 'Node 2', form_content: 'content2' }, + } as HumanInputRequiredResponse) + }) + + expect(store.getState().workflowRunningData!.humanInputFormDataList).toHaveLength(2) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-finished.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-finished.spec.ts new file mode 100644 index 0000000000..87617f0835 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-finished.spec.ts @@ -0,0 +1,42 @@ +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { DEFAULT_ITER_TIMES } from '../../../constants' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeIterationFinished } from '../use-workflow-node-iteration-finished' +import { + createIterationFinishedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeIterationFinished', () => { + it('updates tracing, resets iterTimes, updates node status and edges', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeIterationFinished(), { + nodes: [ + createNode({ + id: 'n1', + data: { _runningStatus: NodeRunningStatus.Running }, + }), + ], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ id: 'iter-1', node_id: 'n1', status: NodeRunningStatus.Running }], + }), + iterTimes: 10, + }, + }) + + act(() => { + result.current.handleWorkflowNodeIterationFinished(createIterationFinishedResponse()) + }) + + expect(store.getState().iterTimes).toBe(DEFAULT_ITER_TIMES) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-next.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-next.spec.ts new file mode 100644 index 0000000000..ac5f2f02ea --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-next.spec.ts @@ -0,0 +1,28 @@ +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { useWorkflowNodeIterationNext } from '../use-workflow-node-iteration-next' +import { + createIterationNextResponse, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeIterationNext', () => { + it('sets _iterationIndex and increments iterTimes', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeIterationNext(), { + initialStoreState: { + workflowRunningData: baseRunningData(), + iterTimes: 3, + }, + }) + + act(() => { + result.current.handleWorkflowNodeIterationNext(createIterationNextResponse()) + }) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._iterationIndex).toBe(3) + }) + expect(store.getState().iterTimes).toBe(4) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-started.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-started.spec.ts new file mode 100644 index 0000000000..ccff1b288b --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-started.spec.ts @@ -0,0 +1,49 @@ +import type { IterationStartedResponse } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { DEFAULT_ITER_TIMES } from '../../../constants' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeIterationStarted } from '../use-workflow-node-iteration-started' +import { + containerParams, + createViewportNodes, + getEdgeRuntimeState, + getNodeRuntimeState, + renderViewportHook, +} from './test-helpers' + +describe('useWorkflowNodeIterationStarted', () => { + it('pushes to tracing, resets iterTimes, sets viewport, and updates node with _iterationLength', async () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeIterationStarted(), { + nodes: createViewportNodes().slice(0, 2), + initialStoreState: { + workflowRunningData: baseRunningData(), + iterTimes: 99, + }, + }) + + act(() => { + result.current.handleWorkflowNodeIterationStarted( + { data: { node_id: 'n1', metadata: { iterator_length: 10 } } } as IterationStartedResponse, + containerParams, + ) + }) + + const tracing = store.getState().workflowRunningData!.tracing! + expect(tracing[0].status).toBe(NodeRunningStatus.Running) + expect(store.getState().iterTimes).toBe(DEFAULT_ITER_TIMES) + + await waitFor(() => { + const transform = result.current.reactFlowStore.getState().transform + expect(transform[0]).toBe(200) + expect(transform[1]).toBe(310) + expect(transform[2]).toBe(1) + + const node = result.current.nodes.find(item => item.id === 'n1') + expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(node)._iterationLength).toBe(10) + expect(getNodeRuntimeState(node)._waitingRun).toBe(false) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-finished.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-finished.spec.ts new file mode 100644 index 0000000000..7acd9897ed --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-finished.spec.ts @@ -0,0 +1,40 @@ +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeLoopFinished } from '../use-workflow-node-loop-finished' +import { + createLoopFinishedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeLoopFinished', () => { + it('updates tracing, node status and edges', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeLoopFinished(), { + nodes: [ + createNode({ + id: 'n1', + data: { _runningStatus: NodeRunningStatus.Running }, + }), + ], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ id: 'loop-1', node_id: 'n1', status: NodeRunningStatus.Running }], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeLoopFinished(createLoopFinishedResponse()) + }) + + expect(store.getState().workflowRunningData!.tracing![0].status).toBe(NodeRunningStatus.Succeeded) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-next.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-next.spec.ts new file mode 100644 index 0000000000..5baa44c983 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-next.spec.ts @@ -0,0 +1,38 @@ +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeLoopNext } from '../use-workflow-node-loop-next' +import { + createLoopNextResponse, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeLoopNext', () => { + it('sets _loopIndex and resets child nodes to waiting', async () => { + const { result } = renderRunEventHook(() => useWorkflowNodeLoopNext(), { + nodes: [ + createNode({ id: 'n1', data: {} }), + createNode({ + id: 'n2', + position: { x: 300, y: 0 }, + parentId: 'n1', + data: { _waitingRun: false }, + }), + ], + edges: [], + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeLoopNext(createLoopNextResponse()) + }) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n1'))._loopIndex).toBe(5) + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n2'))._waitingRun).toBe(true) + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n2'))._runningStatus).toBe(NodeRunningStatus.Waiting) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-started.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-started.spec.ts new file mode 100644 index 0000000000..b0e8bf2cc5 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-started.spec.ts @@ -0,0 +1,43 @@ +import type { LoopStartedResponse } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeLoopStarted } from '../use-workflow-node-loop-started' +import { + containerParams, + createViewportNodes, + getEdgeRuntimeState, + getNodeRuntimeState, + renderViewportHook, +} from './test-helpers' + +describe('useWorkflowNodeLoopStarted', () => { + it('pushes to tracing, sets viewport, and updates node with _loopLength', async () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeLoopStarted(), { + nodes: createViewportNodes().slice(0, 2), + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeLoopStarted( + { data: { node_id: 'n1', metadata: { loop_length: 5 } } } as LoopStartedResponse, + containerParams, + ) + }) + + expect(store.getState().workflowRunningData!.tracing![0].status).toBe(NodeRunningStatus.Running) + + await waitFor(() => { + const transform = result.current.reactFlowStore.getState().transform + expect(transform[0]).toBe(200) + expect(transform[1]).toBe(310) + expect(transform[2]).toBe(1) + + const node = result.current.nodes.find(item => item.id === 'n1') + expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(node)._loopLength).toBe(5) + expect(getNodeRuntimeState(node)._waitingRun).toBe(false) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-retry.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-retry.spec.ts new file mode 100644 index 0000000000..b3c6b814b1 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-retry.spec.ts @@ -0,0 +1,27 @@ +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { useWorkflowNodeRetry } from '../use-workflow-node-retry' +import { + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeRetry', () => { + it('pushes retry data to tracing and updates _retryIndex', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeRetry(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeRetry({ + data: { node_id: 'n1', retry_index: 2 }, + } as never) + }) + + expect(store.getState().workflowRunningData!.tracing).toHaveLength(1) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._retryIndex).toBe(2) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-started.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-started.spec.ts new file mode 100644 index 0000000000..a8a52e0a84 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-started.spec.ts @@ -0,0 +1,80 @@ +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeStarted } from '../use-workflow-node-started' +import { + containerParams, + createNodeStartedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + renderViewportHook, +} from './test-helpers' + +describe('useWorkflowNodeStarted', () => { + it('pushes to tracing, sets node running, and adjusts viewport for root node', async () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeStarted(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeStarted(createNodeStartedResponse(), containerParams) + }) + + const tracing = store.getState().workflowRunningData!.tracing! + expect(tracing).toHaveLength(1) + expect(tracing[0].status).toBe(NodeRunningStatus.Running) + + await waitFor(() => { + const transform = result.current.reactFlowStore.getState().transform + expect(transform[0]).toBe(200) + expect(transform[1]).toBe(310) + expect(transform[2]).toBe(1) + + const node = result.current.nodes.find(item => item.id === 'n1') + expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(node)._waitingRun).toBe(false) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) + }) + }) + + it('does not adjust viewport for child nodes', async () => { + const { result } = renderViewportHook(() => useWorkflowNodeStarted(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeStarted(createNodeStartedResponse({ + data: { node_id: 'n2' } as never, + }), containerParams) + }) + + await waitFor(() => { + const transform = result.current.reactFlowStore.getState().transform + expect(transform[0]).toBe(0) + expect(transform[1]).toBe(0) + expect(transform[2]).toBe(1) + expect(getNodeRuntimeState(result.current.nodes.find(item => item.id === 'n2'))._runningStatus).toBe(NodeRunningStatus.Running) + }) + }) + + it('updates existing tracing entry when node_id already exists', () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeStarted(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [ + { node_id: 'n0', status: NodeRunningStatus.Succeeded } as never, + { node_id: 'n1', status: NodeRunningStatus.Succeeded } as never, + ], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeStarted(createNodeStartedResponse(), containerParams) + }) + + const tracing = store.getState().workflowRunningData!.tracing! + expect(tracing).toHaveLength(2) + expect(tracing[1].status).toBe(NodeRunningStatus.Running) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-paused.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-paused.spec.ts new file mode 100644 index 0000000000..9cfb8f62d9 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-paused.spec.ts @@ -0,0 +1,15 @@ +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { WorkflowRunningStatus } from '../../../types' +import { useWorkflowPaused } from '../use-workflow-paused' + +describe('useWorkflowPaused', () => { + it('sets status to Paused', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowPaused(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + result.current.handleWorkflowPaused() + + expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Paused) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-run-event.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-run-event.spec.ts new file mode 100644 index 0000000000..fb8ea51638 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-run-event.spec.ts @@ -0,0 +1,54 @@ +import { renderHook } from '@testing-library/react' +import { useWorkflowRunEvent } from '../use-workflow-run-event' + +const handlers = vi.hoisted(() => ({ + handleWorkflowStarted: vi.fn(), + handleWorkflowFinished: vi.fn(), + handleWorkflowFailed: vi.fn(), + handleWorkflowNodeStarted: vi.fn(), + handleWorkflowNodeFinished: vi.fn(), + handleWorkflowNodeIterationStarted: vi.fn(), + handleWorkflowNodeIterationNext: vi.fn(), + handleWorkflowNodeIterationFinished: vi.fn(), + handleWorkflowNodeLoopStarted: vi.fn(), + handleWorkflowNodeLoopNext: vi.fn(), + handleWorkflowNodeLoopFinished: vi.fn(), + handleWorkflowNodeRetry: vi.fn(), + handleWorkflowTextChunk: vi.fn(), + handleWorkflowTextReplace: vi.fn(), + handleWorkflowAgentLog: vi.fn(), + handleWorkflowPaused: vi.fn(), + handleWorkflowNodeHumanInputRequired: vi.fn(), + handleWorkflowNodeHumanInputFormFilled: vi.fn(), + handleWorkflowNodeHumanInputFormTimeout: vi.fn(), +})) + +vi.mock('..', () => ({ + useWorkflowStarted: () => ({ handleWorkflowStarted: handlers.handleWorkflowStarted }), + useWorkflowFinished: () => ({ handleWorkflowFinished: handlers.handleWorkflowFinished }), + useWorkflowFailed: () => ({ handleWorkflowFailed: handlers.handleWorkflowFailed }), + useWorkflowNodeStarted: () => ({ handleWorkflowNodeStarted: handlers.handleWorkflowNodeStarted }), + useWorkflowNodeFinished: () => ({ handleWorkflowNodeFinished: handlers.handleWorkflowNodeFinished }), + useWorkflowNodeIterationStarted: () => ({ handleWorkflowNodeIterationStarted: handlers.handleWorkflowNodeIterationStarted }), + useWorkflowNodeIterationNext: () => ({ handleWorkflowNodeIterationNext: handlers.handleWorkflowNodeIterationNext }), + useWorkflowNodeIterationFinished: () => ({ handleWorkflowNodeIterationFinished: handlers.handleWorkflowNodeIterationFinished }), + useWorkflowNodeLoopStarted: () => ({ handleWorkflowNodeLoopStarted: handlers.handleWorkflowNodeLoopStarted }), + useWorkflowNodeLoopNext: () => ({ handleWorkflowNodeLoopNext: handlers.handleWorkflowNodeLoopNext }), + useWorkflowNodeLoopFinished: () => ({ handleWorkflowNodeLoopFinished: handlers.handleWorkflowNodeLoopFinished }), + useWorkflowNodeRetry: () => ({ handleWorkflowNodeRetry: handlers.handleWorkflowNodeRetry }), + useWorkflowTextChunk: () => ({ handleWorkflowTextChunk: handlers.handleWorkflowTextChunk }), + useWorkflowTextReplace: () => ({ handleWorkflowTextReplace: handlers.handleWorkflowTextReplace }), + useWorkflowAgentLog: () => ({ handleWorkflowAgentLog: handlers.handleWorkflowAgentLog }), + useWorkflowPaused: () => ({ handleWorkflowPaused: handlers.handleWorkflowPaused }), + useWorkflowNodeHumanInputRequired: () => ({ handleWorkflowNodeHumanInputRequired: handlers.handleWorkflowNodeHumanInputRequired }), + useWorkflowNodeHumanInputFormFilled: () => ({ handleWorkflowNodeHumanInputFormFilled: handlers.handleWorkflowNodeHumanInputFormFilled }), + useWorkflowNodeHumanInputFormTimeout: () => ({ handleWorkflowNodeHumanInputFormTimeout: handlers.handleWorkflowNodeHumanInputFormTimeout }), +})) + +describe('useWorkflowRunEvent', () => { + it('returns the composed handlers from all workflow event hooks', () => { + const { result } = renderHook(() => useWorkflowRunEvent()) + + expect(result.current).toEqual(handlers) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-started.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-started.spec.ts new file mode 100644 index 0000000000..4fd49c9c6a --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-started.spec.ts @@ -0,0 +1,56 @@ +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { WorkflowRunningStatus } from '../../../types' +import { useWorkflowStarted } from '../use-workflow-started' +import { + createStartedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + pausedRunningData, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowStarted', () => { + it('initializes workflow running data and resets nodes and edges', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowStarted(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowStarted(createStartedResponse()) + }) + + const state = store.getState().workflowRunningData! + expect(state.task_id).toBe('task-2') + expect(state.result.status).toBe(WorkflowRunningStatus.Running) + expect(state.resultText).toBe('') + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._waitingRun).toBe(true) + expect(getNodeRuntimeState(result.current.nodes[0])._runningBranchId).toBeUndefined() + expect(getEdgeRuntimeState(result.current.edges[0])._sourceRunningStatus).toBeUndefined() + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBeUndefined() + expect(getEdgeRuntimeState(result.current.edges[0])._waitingRun).toBe(true) + }) + }) + + it('resumes from Paused without resetting nodes or edges', () => { + const { result, store } = renderRunEventHook(() => useWorkflowStarted(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + result: pausedRunningData(), + }), + }, + }) + + act(() => { + result.current.handleWorkflowStarted(createStartedResponse({ + data: { id: 'run-2', workflow_id: 'wf-1', created_at: 2000 }, + })) + }) + + expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Running) + expect(getNodeRuntimeState(result.current.nodes[0])._waitingRun).toBe(false) + expect(getEdgeRuntimeState(result.current.edges[0])._waitingRun).toBeUndefined() + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-chunk.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-chunk.spec.ts new file mode 100644 index 0000000000..fcf36fe596 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-chunk.spec.ts @@ -0,0 +1,19 @@ +import type { TextChunkResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowTextChunk } from '../use-workflow-text-chunk' + +describe('useWorkflowTextChunk', () => { + it('appends text and activates the result tab', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowTextChunk(), { + initialStoreState: { + workflowRunningData: baseRunningData({ resultText: 'Hello' }), + }, + }) + + result.current.handleWorkflowTextChunk({ data: { text: ' World' } } as TextChunkResponse) + + const state = store.getState().workflowRunningData! + expect(state.resultText).toBe('Hello World') + expect(state.resultTabActive).toBe(true) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-replace.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-replace.spec.ts new file mode 100644 index 0000000000..f9c1dcb256 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-replace.spec.ts @@ -0,0 +1,17 @@ +import type { TextReplaceResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowTextReplace } from '../use-workflow-text-replace' + +describe('useWorkflowTextReplace', () => { + it('replaces resultText', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowTextReplace(), { + initialStoreState: { + workflowRunningData: baseRunningData({ resultText: 'old text' }), + }, + }) + + result.current.handleWorkflowTextReplace({ data: { text: 'new text' } } as TextReplaceResponse) + + expect(store.getState().workflowRunningData!.resultText).toBe('new text') + }) +}) diff --git a/web/app/components/workflow/nodes/_base/components/editor/base.tsx b/web/app/components/workflow/nodes/_base/components/editor/base.tsx index 6ed582369c..c0545ff01c 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/base.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/base.tsx @@ -84,7 +84,7 @@ const Base: FC = ({ return ( -
+
{title}
{ expect(screen.getByRole('link', { name: 'workflow.panel.helpLink' })).toHaveAttribute('href', 'https://docs.example.com/node') }) + it('should hide change action when node is undeletable', () => { + mockUseNodeMetaData.mockReturnValueOnce({ + isTypeFixed: false, + isSingleton: true, + isUndeletable: true, + description: 'Undeletable node', + author: 'Dify', + } as ReturnType) + + renderWorkflowFlowComponent( + , + { + nodes: [], + edges: [], + }, + ) + + expect(screen.getByText('workflow.panel.runThisStep')).toBeInTheDocument() + expect(screen.queryByText('workflow.panel.change')).not.toBeInTheDocument() + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + it('should render workflow-tool and readonly popup variants', () => { mockUseAllWorkflowTools.mockReturnValueOnce({ data: [{ id: 'workflow-tool', workflow_app_id: 'app-123' }], diff --git a/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx b/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx index b460aa651c..a93c3e1d14 100644 --- a/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx +++ b/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx @@ -47,7 +47,7 @@ const PanelOperatorPopup = ({ const { nodesReadOnly } = useNodesReadOnly() const edge = edges.find(edge => edge.target === id) const nodeMetaData = useNodeMetaData({ id, data } as Node) - const showChangeBlock = !nodeMetaData.isTypeFixed && !nodesReadOnly + const showChangeBlock = !nodeMetaData.isTypeFixed && !nodeMetaData.isUndeletable && !nodesReadOnly const isChildNode = !!(data.isInIteration || data.isInLoop) const { data: workflowTools } = useAllWorkflowTools() diff --git a/web/app/components/workflow/nodes/_base/components/variable/__tests__/output-var-list.spec.tsx b/web/app/components/workflow/nodes/_base/components/variable/__tests__/output-var-list.spec.tsx new file mode 100644 index 0000000000..24464e4f08 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/variable/__tests__/output-var-list.spec.tsx @@ -0,0 +1,209 @@ +import type { OutputVar } from '../../../../code/types' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import OutputVarList from '../output-var-list' + +vi.mock('../var-type-picker', () => ({ + default: (props: { value: string, onChange: (v: string) => void, readonly: boolean }) => ( + + ), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { error: vi.fn() }, +})) + +describe('OutputVarList', () => { + const createOutputs = (entries: Record = {}): OutputVar => { + const result: OutputVar = {} + for (const [key, type] of Object.entries(entries)) + result[key] = { type: type as OutputVar[string]['type'], children: null } + return result + } + + // Render the component and trigger a rename at the given index. + // Returns the newOutputs passed to onChange. + const collectRenameResult = ( + outputs: OutputVar, + outputKeyOrders: string[], + renameIndex: number, + newName: string, + ): OutputVar => { + let captured: OutputVar | undefined + + render( + { captured = newOutputs }} + onRemove={vi.fn()} + />, + ) + + const inputs = screen.getAllByRole('textbox') + fireEvent.change(inputs[renameIndex], { target: { value: newName } }) + + return captured! + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('duplicate name handling', () => { + it('should preserve outputs entry when renaming one of two duplicate-name variables', () => { + const outputs = createOutputs({ var_1: 'string' }) + const outputKeyOrders = ['var_1', 'var_1'] + + const newOutputs = collectRenameResult(outputs, outputKeyOrders, 1, '') + + // Renamed entry gets a new key '' + expect(newOutputs['']).toEqual({ type: 'string', children: null }) + // Original key 'var_1' must survive because index 0 still uses it + expect(newOutputs.var_1).toEqual({ type: 'string', children: null }) + }) + + it('should delete old key when renamed entry is the only one using it', () => { + const outputs = createOutputs({ var_1: 'string', var_2: 'number' }) + const outputKeyOrders = ['var_1', 'var_2'] + + const newOutputs = collectRenameResult(outputs, outputKeyOrders, 1, 'renamed') + + expect(newOutputs.renamed).toEqual({ type: 'number', children: null }) + expect(newOutputs.var_2).toBeUndefined() + expect(newOutputs.var_1).toEqual({ type: 'string', children: null }) + }) + + it('should keep outputs key alive when duplicate is renamed back to unique name', () => { + // Step 1: rename var_2 -> var_1 (creates duplicate) + const outputs = createOutputs({ var_1: 'string', var_2: 'number' }) + const afterFirst = collectRenameResult(outputs, ['var_1', 'var_2'], 1, 'var_1') + + expect(afterFirst.var_2).toBeUndefined() + expect(afterFirst.var_1).toBeDefined() + + // Clean up first render before the second to avoid DOM collision + cleanup() + + // Step 2: rename second var_1 -> var_2 (restores unique names) + const afterSecond = collectRenameResult(afterFirst, ['var_1', 'var_1'], 1, 'var_2') + + // var_1 must survive because index 0 still uses it + expect(afterSecond.var_1).toBeDefined() + expect(afterSecond.var_2).toBeDefined() + }) + }) + + describe('removal with duplicate names', () => { + it('should call onRemove with correct index when removing a duplicate', () => { + const outputs = createOutputs({ var_1: 'string' }) + const onRemove = vi.fn() + + render( + , + ) + + // The second remove button (index 1 in the row) + const buttons = screen.getAllByRole('button') + fireEvent.click(buttons[1]) + + expect(onRemove).toHaveBeenCalledWith(1) + }) + }) + + describe('normal operation', () => { + it('should render one row per outputKeyOrders entry', () => { + const outputs = createOutputs({ a: 'string', b: 'number' }) + const onChange = vi.fn() + + render( + , + ) + + const inputs = screen.getAllByRole('textbox') + expect(inputs).toHaveLength(2) + expect(inputs[0]).toHaveValue('a') + expect(inputs[1]).toHaveValue('b') + }) + + it('should call onChange with updated outputs when renaming', () => { + const outputs = createOutputs({ var_1: 'string' }) + const onChange = vi.fn() + + render( + , + ) + + fireEvent.change(screen.getByRole('textbox'), { target: { value: 'new_name' } }) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + new_name: { type: 'string', children: null }, + }), + 0, + 'new_name', + ) + }) + + it('should call onRemove when remove button is clicked', () => { + const outputs = createOutputs({ var_1: 'string' }) + const onRemove = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByRole('button')) + + expect(onRemove).toHaveBeenCalledWith(0) + }) + + it('should render inputs as readonly when readonly is true', () => { + const outputs = createOutputs({ var_1: 'string' }) + + render( + , + ) + + expect(screen.getByRole('textbox')).toHaveAttribute('readonly') + }) + }) +}) diff --git a/web/app/components/workflow/nodes/_base/components/variable/output-var-list.tsx b/web/app/components/workflow/nodes/_base/components/variable/output-var-list.tsx index b9a1bc524e..79238aa6de 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/output-var-list.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/output-var-list.tsx @@ -59,7 +59,9 @@ const OutputVarList: FC = ({ const newOutputs = produce(outputs, (draft) => { draft[newKey] = draft[oldKey] - delete draft[oldKey] + // Only delete old key if no other entry shares this name + if (!list.some((item, i) => i !== index && item.variable === oldKey)) + delete draft[oldKey] }) onChange(newOutputs, index, newKey) } diff --git a/web/app/components/workflow/nodes/_base/hooks/__tests__/use-toggle-expend.spec.ts b/web/app/components/workflow/nodes/_base/hooks/__tests__/use-toggle-expend.spec.ts new file mode 100644 index 0000000000..266800d9aa --- /dev/null +++ b/web/app/components/workflow/nodes/_base/hooks/__tests__/use-toggle-expend.spec.ts @@ -0,0 +1,123 @@ +import { act, renderHook } from '@testing-library/react' +import { useRef } from 'react' +import useToggleExpend from '../use-toggle-expend' + +type HookProps = { + hasFooter?: boolean + isInNode?: boolean + clientHeight?: number +} + +/** + * Wrapper that provides a real ref whose `.current.clientHeight` is stubbed + * so we can verify the height math without a real DOM layout pass. + */ +function useHarness({ hasFooter, isInNode, clientHeight = 400 }: HookProps) { + const ref = useRef(null) + + // Stub a ref-like object so measurements are deterministic. + if (!ref.current) { + Object.defineProperty(ref, 'current', { + value: { clientHeight } as HTMLDivElement, + writable: true, + }) + } + + return useToggleExpend({ ref, hasFooter, isInNode }) +} + +describe('useToggleExpend', () => { + describe('collapsed state', () => { + it('returns empty wrapClassName and zero expand height when collapsed', () => { + const { result } = renderHook(() => useHarness({ clientHeight: 400 })) + + expect(result.current.isExpand).toBe(false) + expect(result.current.wrapClassName).toBe('') + expect(result.current.editorExpandHeight).toBe(0) + }) + }) + + describe('expanded state (node context)', () => { + it('uses fixed positioning inside a workflow node panel', () => { + const { result } = renderHook(() => + useHarness({ isInNode: true, clientHeight: 400 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + expect(result.current.isExpand).toBe(true) + expect(result.current.wrapClassName).toContain('fixed') + expect(result.current.wrapClassName).toContain('bg-components-panel-bg') + expect(result.current.wrapStyle).toEqual( + expect.objectContaining({ boxShadow: expect.any(String) }), + ) + }) + }) + + describe('expanded state (execution-log / webapp context)', () => { + it('fills its positioned ancestor edge-to-edge without hardcoded offsets', () => { + const { result } = renderHook(() => + useHarness({ isInNode: false, clientHeight: 400 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + // The expanded panel must fill the nearest positioned ancestor entirely + // (absolute + inset-0). Previously it used hardcoded `top-[52px]` which + // assumed a 52px header that does not exist in the conversation-log + // layout, causing the expanded panel to overlap the status bar above + // the editor (#34887). + expect(result.current.wrapClassName).toContain('absolute') + expect(result.current.wrapClassName).toContain('inset-0') + expect(result.current.wrapClassName).not.toMatch(/top-\[\d+px\]/) + expect(result.current.wrapClassName).not.toMatch(/left-\d+/) + expect(result.current.wrapClassName).not.toMatch(/right-\d+/) + expect(result.current.wrapClassName).toContain('bg-components-panel-bg') + }) + }) + + describe('expanded state height math', () => { + it('subtracts the 29px chrome when hasFooter is false', () => { + const { result } = renderHook(() => + useHarness({ hasFooter: false, clientHeight: 400 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + // 400 (clientHeight) - 29 (title bar) = 371 + expect(result.current.editorExpandHeight).toBe(371) + }) + + it('subtracts the 56px chrome when hasFooter is true', () => { + const { result } = renderHook(() => + useHarness({ hasFooter: true, clientHeight: 400 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + // 400 (clientHeight) - 56 (title bar + footer) = 344 + expect(result.current.editorExpandHeight).toBe(344) + }) + + it('never returns a negative height even if chrome exceeds wrap', () => { + const { result } = renderHook(() => + useHarness({ hasFooter: true, clientHeight: 20 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + // 20 - 56 would be -36; clamped to 0. + expect(result.current.editorExpandHeight).toBe(0) + }) + }) +}) diff --git a/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts b/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts index 09b8fde0b5..a77af2daef 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts @@ -134,19 +134,24 @@ function useOutputVarList({ return } + const newOutputKeyOrders = outputKeyOrders.filter((_, i) => i !== index) const newInputs = produce(inputs, (draft: any) => { - delete draft[varKey][key] + // Only delete from outputs when no remaining entry shares this name + if (!newOutputKeyOrders.includes(key)) + delete draft[varKey][key] if ((inputs as CodeNodeType).type === BlockEnum.Code && (inputs as CodeNodeType).error_strategy === ErrorHandleTypeEnum.defaultValue && varKey === 'outputs') draft.default_value = getDefaultValue(draft as any) }) setInputs(newInputs) - onOutputKeyOrdersChange(outputKeyOrders.filter((_, i) => i !== index)) - const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { - return varItem.name === key - })?.id - if (varId) - deleteInspectVar(id, varId) + onOutputKeyOrdersChange(newOutputKeyOrders) + if (!newOutputKeyOrders.includes(key)) { + const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { + return varItem.name === key + })?.id + if (varId) + deleteInspectVar(id, varId) + } }, [outputKeyOrders, isVarUsedInNodes, id, inputs, setInputs, onOutputKeyOrdersChange, nodesWithInspectVars, deleteInspectVar, showRemoveVarConfirm, varKey]) return { diff --git a/web/app/components/workflow/nodes/_base/hooks/use-toggle-expend.ts b/web/app/components/workflow/nodes/_base/hooks/use-toggle-expend.ts index c123c00e2d..1afeb8db12 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-toggle-expend.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-toggle-expend.ts @@ -1,4 +1,4 @@ -import { useEffect, useState } from 'react' +import { useLayoutEffect, useState } from 'react' type Params = { ref?: React.RefObject @@ -6,30 +6,62 @@ type Params = { isInNode?: boolean } +// Chrome (title bar + optional footer) heights subtracted from the wrap so +// the editor body never paints underneath its own controls. +const CHROME_HEIGHT_WITH_FOOTER = 56 +const CHROME_HEIGHT_WITHOUT_FOOTER = 29 + +/** + * Controls the expand/collapse behavior of the code editor wrapper used across + * workflow nodes and execution-log panels. + * + * Returns: + * - `wrapClassName` / `wrapStyle` — positioning + shadow applied to the outer + * wrapper when the editor is expanded. + * - `editorExpandHeight` — height for the editor body (wrap minus chrome). + * - `isExpand` / `setIsExpand` — state + setter for the consumer. + * + * Height is measured via `useLayoutEffect` so the first expanded render + * already has the correct value — the previous `useEffect` implementation + * left the editor at the collapsed height for one paint on first expand. + */ const useToggleExpend = ({ ref, hasFooter = true, isInNode }: Params) => { const [isExpand, setIsExpand] = useState(false) - const [wrapHeight, setWrapHeight] = useState(ref?.current?.clientHeight) - const editorExpandHeight = isExpand ? wrapHeight! - (hasFooter ? 56 : 29) : 0 - useEffect(() => { + const [wrapHeight, setWrapHeight] = useState(undefined) + + useLayoutEffect(() => { if (!ref?.current) return - setWrapHeight(ref.current?.clientHeight) - }, [isExpand]) + setWrapHeight(ref.current.clientHeight) + }, [isExpand, ref]) + + const chromeHeight = hasFooter ? CHROME_HEIGHT_WITH_FOOTER : CHROME_HEIGHT_WITHOUT_FOOTER + const editorExpandHeight = isExpand && wrapHeight !== undefined + ? Math.max(0, wrapHeight - chromeHeight) + : 0 const wrapClassName = (() => { if (!isExpand) return '' if (isInNode) - return 'fixed z-10 right-[9px] top-[166px] bottom-[8px] p-4 bg-components-panel-bg rounded-xl' + return 'fixed z-10 right-[9px] top-[166px] bottom-[8px] p-4 bg-components-panel-bg rounded-xl' - return 'absolute z-10 left-4 right-6 top-[52px] bottom-0 pb-4 bg-components-panel-bg' + // Fill the nearest positioned ancestor entirely. Previously hardcoded + // `top-[52px] left-4 right-6` offsets assumed a 52px header above the + // scroll container — that assumption no longer holds in the conversation + // log (result-panel) layout, where the status bar above the editor is + // taller than 52px, causing the expanded panel to partially overlap the + // status bar (issue #34887). + return 'absolute z-10 inset-0 pb-4 bg-components-panel-bg' })() + const wrapStyle = isExpand ? { boxShadow: '0px 0px 12px -4px rgba(16, 24, 40, 0.05), 0px -3px 6px -2px rgba(16, 24, 40, 0.03)', } : {} + return { wrapClassName, wrapStyle, diff --git a/web/app/components/workflow/nodes/agent/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/agent/__tests__/node.spec.tsx new file mode 100644 index 0000000000..025c9bd84c --- /dev/null +++ b/web/app/components/workflow/nodes/agent/__tests__/node.spec.tsx @@ -0,0 +1,249 @@ +import type { ReactNode } from 'react' +import type { AgentNodeType } from '../types' +import type useConfig from '../use-config' +import type { StrategyParamItem } from '@/app/components/plugins/types' +import { render, screen } from '@testing-library/react' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { BlockEnum } from '@/app/components/workflow/types' +import { VarType } from '../../tool/types' +import Node from '../node' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockModelBar = vi.hoisted(() => vi.fn()) +const mockToolIcon = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('@/hooks/use-i18n', () => ({ + useRenderI18nObject: () => (value: string | { en_US?: string }) => typeof value === 'string' ? value : value.en_US || '', +})) + +vi.mock('../components/model-bar', () => ({ + ModelBar: (props: { provider?: string, model?: string, param: string }) => { + mockModelBar(props) + return
{props.provider ? `${props.param}:${props.provider}/${props.model}` : `${props.param}:empty-model`}
+ }, +})) + +vi.mock('../components/tool-icon', () => ({ + ToolIcon: (props: { providerName: string }) => { + mockToolIcon(props) + return
{`tool:${props.providerName}`}
+ }, +})) + +vi.mock('../../_base/components/group', () => ({ + Group: ({ label, children }: { label: ReactNode, children: ReactNode }) => ( +
+
{label}
+ {children} +
+ ), + GroupLabel: ({ className, children }: { className?: string, children: ReactNode }) =>
{children}
, +})) + +vi.mock('../../_base/components/setting-item', () => ({ + SettingItem: ({ + label, + status, + tooltip, + children, + }: { + label: ReactNode + status?: string + tooltip?: string + children?: ReactNode + }) => ( +
+ {`${label}:${status || 'normal'}:${tooltip || ''}`} + {children} +
+ ), +})) + +const createStrategyParam = (overrides: Partial = {}): StrategyParamItem => ({ + name: 'requiredModel', + type: FormTypeEnum.modelSelector, + required: true, + label: { en_US: 'Required Model' } as StrategyParamItem['label'], + help: { en_US: 'Required model help' } as StrategyParamItem['help'], + placeholder: { en_US: 'Required model placeholder' } as StrategyParamItem['placeholder'], + scope: 'global', + default: null, + options: [], + template: { enabled: false }, + auto_generate: { type: 'none' }, + ...overrides, +}) + +const createData = (overrides: Partial = {}): AgentNodeType => ({ + title: 'Agent', + desc: '', + type: BlockEnum.Agent, + output_schema: {}, + agent_strategy_provider_name: 'provider/agent', + agent_strategy_name: 'react', + agent_strategy_label: 'React Agent', + plugin_unique_identifier: 'provider/agent:1.0.0', + agent_parameters: { + optionalModel: { + type: VarType.constant, + value: { provider: 'openai', model: 'gpt-4o' }, + }, + toolParam: { + type: VarType.constant, + value: { provider_name: 'author/tool-a' }, + }, + multiToolParam: { + type: VarType.constant, + value: [ + { provider_name: 'author/tool-b' }, + { provider_name: 'author/tool-c' }, + ], + }, + }, + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + setInputs: vi.fn(), + handleVarListChange: vi.fn(), + handleAddVariable: vi.fn(), + currentStrategy: { + identity: { + author: 'provider', + name: 'react', + icon: 'icon', + label: { en_US: 'React Agent' } as StrategyParamItem['label'], + provider: 'provider/agent', + }, + parameters: [ + createStrategyParam(), + createStrategyParam({ + name: 'optionalModel', + required: false, + }), + createStrategyParam({ + name: 'toolParam', + type: FormTypeEnum.toolSelector, + required: false, + }), + createStrategyParam({ + name: 'multiToolParam', + type: FormTypeEnum.multiToolSelector, + required: false, + }), + ], + description: { en_US: 'agent description' } as StrategyParamItem['label'], + output_schema: {}, + features: [], + }, + formData: {}, + onFormChange: vi.fn(), + currentStrategyStatus: { + plugin: { source: 'marketplace', installed: true }, + isExistInPlugin: false, + }, + strategyProvider: undefined, + pluginDetail: ({ + declaration: { + label: { en_US: 'Plugin Marketplace' } as never, + }, + } as never), + availableVars: [], + availableNodesWithParent: [], + outputSchema: [], + handleMemoryChange: vi.fn(), + isChatMode: true, + ...overrides, +}) + +describe('agent/node', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders the not-set state when no strategy is configured', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + inputs: createData({ + agent_strategy_name: undefined, + agent_strategy_label: undefined, + agent_parameters: {}, + }), + currentStrategy: undefined, + })) + + render( + , + ) + + expect(screen.getByText('workflow.nodes.agent.strategyNotSet:normal:')).toBeInTheDocument() + expect(mockModelBar).not.toHaveBeenCalled() + expect(mockToolIcon).not.toHaveBeenCalled() + }) + + it('renders strategy status, required and selected model bars, and tool icons', () => { + render( + , + ) + + expect(screen.getByText(/workflow.nodes.agent.strategy.shortLabel:error:/)).toHaveTextContent('React Agent') + expect(screen.getByText(/workflow.nodes.agent.strategy.shortLabel:error:/)).toHaveTextContent('Plugin Marketplace') + expect(screen.getByText('requiredModel:empty-model')).toBeInTheDocument() + expect(screen.getByText('optionalModel:openai/gpt-4o')).toBeInTheDocument() + expect(screen.getByText('tool:author/tool-a')).toBeInTheDocument() + expect(screen.getByText('tool:author/tool-b')).toBeInTheDocument() + expect(screen.getByText('tool:author/tool-c')).toBeInTheDocument() + expect(mockModelBar).toHaveBeenCalledTimes(2) + expect(mockToolIcon).toHaveBeenCalledTimes(3) + }) + + it('skips optional models and empty tool values when no configuration is provided', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + inputs: createData({ + agent_parameters: {}, + }), + currentStrategy: { + ...createConfigResult().currentStrategy!, + parameters: [ + createStrategyParam({ + name: 'optionalModel', + required: false, + }), + createStrategyParam({ + name: 'toolParam', + type: FormTypeEnum.toolSelector, + required: false, + }), + ], + }, + currentStrategyStatus: { + plugin: { source: 'marketplace', installed: true }, + isExistInPlugin: true, + }, + })) + + render( + , + ) + + expect(mockModelBar).not.toHaveBeenCalled() + expect(mockToolIcon).not.toHaveBeenCalled() + expect(screen.queryByText('optionalModel:empty-model')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/agent/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/agent/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..15001b4757 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/__tests__/panel.spec.tsx @@ -0,0 +1,297 @@ +import type { ReactNode } from 'react' +import type { AgentNodeType } from '../types' +import type useConfig from '../use-config' +import type { StrategyParamItem } from '@/app/components/plugins/types' +import type { NodePanelProps } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' +import { AgentFeature } from '../types' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockResetEditor = vi.hoisted(() => vi.fn()) +const mockAgentStrategy = vi.hoisted(() => vi.fn()) +const mockMemoryConfig = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../../../store', () => ({ + useStore: (selector: (state: { setControlPromptEditorRerenderKey: typeof mockResetEditor }) => unknown) => selector({ + setControlPromptEditorRerenderKey: mockResetEditor, + }), +})) + +vi.mock('../../_base/components/agent-strategy', () => ({ + AgentStrategy: (props: { + strategy?: { + agent_strategy_provider_name: string + agent_strategy_name: string + agent_strategy_label: string + agent_output_schema: AgentNodeType['output_schema'] + plugin_unique_identifier: string + meta?: AgentNodeType['meta'] + } + formSchema: Array<{ variable: string, tooltip?: StrategyParamItem['help'] }> + formValue: Record + onStrategyChange: (strategy: { + agent_strategy_provider_name: string + agent_strategy_name: string + agent_strategy_label: string + agent_output_schema: AgentNodeType['output_schema'] + plugin_unique_identifier: string + meta?: AgentNodeType['meta'] + }) => void + onFormValueChange: (value: Record) => void + }) => { + mockAgentStrategy(props) + return ( +
+ + +
+ ) + }, +})) + +vi.mock('../../_base/components/memory-config', () => ({ + __esModule: true, + default: (props: { + readonly?: boolean + config: { data?: AgentNodeType['memory'] } + onChange: (value?: AgentNodeType['memory']) => void + }) => { + mockMemoryConfig(props) + return ( + + ) + }, +})) + +vi.mock('../../_base/components/output-vars', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) =>
{children}
, + VarItem: ({ name, type, description }: { name: string, type: string, description?: string }) => ( +
{`${name}:${type}:${description || ''}`}
+ ), +})) + +const createStrategyParam = (overrides: Partial = {}): StrategyParamItem => ({ + name: 'instruction', + type: FormTypeEnum.any, + required: true, + label: { en_US: 'Instruction' } as StrategyParamItem['label'], + help: { en_US: 'Instruction help' } as StrategyParamItem['help'], + placeholder: { en_US: 'Instruction placeholder' } as StrategyParamItem['placeholder'], + scope: 'global', + default: null, + options: [], + template: { enabled: false }, + auto_generate: { type: 'none' }, + ...overrides, +}) + +const createData = (overrides: Partial = {}): AgentNodeType => ({ + title: 'Agent', + desc: '', + type: BlockEnum.Agent, + output_schema: { + properties: { + summary: { + type: 'string', + description: 'summary output', + }, + }, + }, + agent_strategy_provider_name: 'provider/agent', + agent_strategy_name: 'react', + agent_strategy_label: 'React Agent', + plugin_unique_identifier: 'provider/agent:1.0.0', + meta: { version: '1.0.0' } as AgentNodeType['meta'], + memory: { + window: { + enabled: false, + size: 3, + }, + query_prompt_template: '', + } as AgentNodeType['memory'], + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + setInputs: vi.fn(), + handleVarListChange: vi.fn(), + handleAddVariable: vi.fn(), + currentStrategy: { + identity: { + author: 'provider', + name: 'react', + icon: 'icon', + label: { en_US: 'React Agent' } as StrategyParamItem['label'], + provider: 'provider/agent', + }, + parameters: [ + createStrategyParam(), + createStrategyParam({ + name: 'modelParam', + type: FormTypeEnum.modelSelector, + required: false, + }), + ], + description: { en_US: 'agent description' } as StrategyParamItem['label'], + output_schema: {}, + features: [AgentFeature.HISTORY_MESSAGES], + }, + formData: { + instruction: 'Plan and answer', + }, + onFormChange: vi.fn(), + currentStrategyStatus: { + plugin: { source: 'marketplace', installed: true }, + isExistInPlugin: true, + }, + strategyProvider: undefined, + pluginDetail: undefined, + availableVars: [], + availableNodesWithParent: [], + outputSchema: [{ + name: 'summary', + type: 'String', + description: 'summary output', + }], + handleMemoryChange: vi.fn(), + isChatMode: true, + ...overrides, +}) + +const panelProps = {} as NodePanelProps['panelProps'] + +describe('agent/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders strategy data, forwards strategy and form updates, and exposes output vars', async () => { + const user = userEvent.setup() + const setInputs = vi.fn() + const onFormChange = vi.fn() + const handleMemoryChange = vi.fn() + + mockUseConfig.mockReturnValue(createConfigResult({ + setInputs, + onFormChange, + handleMemoryChange, + })) + + render( + , + ) + + expect(screen.getByText('text:String:workflow.nodes.agent.outputVars.text')).toBeInTheDocument() + expect(screen.getByText('usage:object:workflow.nodes.agent.outputVars.usage')).toBeInTheDocument() + expect(screen.getByText('files:Array[File]:workflow.nodes.agent.outputVars.files.title')).toBeInTheDocument() + expect(screen.getByText('json:Array[Object]:workflow.nodes.agent.outputVars.json')).toBeInTheDocument() + expect(screen.getByText('summary:String:summary output')).toBeInTheDocument() + expect(mockAgentStrategy).toHaveBeenCalledWith(expect.objectContaining({ + formSchema: expect.arrayContaining([ + expect.objectContaining({ + variable: 'instruction', + tooltip: { en_US: 'Instruction help' }, + }), + expect.objectContaining({ + variable: 'modelParam', + }), + ]), + formValue: { + instruction: 'Plan and answer', + }, + })) + + await user.click(screen.getByRole('button', { name: 'change-strategy' })) + await user.click(screen.getByRole('button', { name: 'change-form' })) + await user.click(screen.getByRole('button', { name: 'change-memory' })) + + expect(setInputs).toHaveBeenCalledWith(expect.objectContaining({ + agent_strategy_provider_name: 'provider/updated', + agent_strategy_name: 'updated', + agent_strategy_label: 'Updated Strategy', + plugin_unique_identifier: 'provider/updated:1.0.0', + output_schema: expect.objectContaining({ + properties: expect.objectContaining({ + structured: expect.any(Object), + }), + }), + })) + expect(onFormChange).toHaveBeenCalledWith({ instruction: 'Use the tool' }) + expect(handleMemoryChange).toHaveBeenCalledWith(expect.objectContaining({ + query_prompt_template: 'history', + })) + expect(mockResetEditor).toHaveBeenCalledTimes(1) + }) + + it('hides memory config when chat mode support is unavailable', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + isChatMode: false, + currentStrategy: { + ...createConfigResult().currentStrategy!, + features: [], + }, + })) + + render( + , + ) + + expect(screen.queryByRole('button', { name: 'change-memory' })).not.toBeInTheDocument() + expect(mockMemoryConfig).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/nodes/agent/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/agent/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..9e09ab6d78 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/__tests__/use-config.spec.ts @@ -0,0 +1,422 @@ +import type { AgentNodeType } from '../types' +import type { StrategyParamItem } from '@/app/components/plugins/types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { BlockEnum, VarType as WorkflowVarType } from '@/app/components/workflow/types' +import { VarType } from '../../tool/types' +import useConfig, { useStrategyInfo } from '../use-config' + +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseIsChatMode = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseVarList = vi.hoisted(() => vi.fn()) +const mockUseAvailableVarList = vi.hoisted(() => vi.fn()) +const mockUseStrategyProviderDetail = vi.hoisted(() => vi.fn()) +const mockUseFetchPluginsInMarketPlaceByIds = vi.hoisted(() => vi.fn()) +const mockUseCheckInstalled = vi.hoisted(() => vi.fn()) +const mockGenerateAgentToolValue = vi.hoisted(() => vi.fn()) +const mockToolParametersToFormSchemas = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: (...args: unknown[]) => mockUseNodesReadOnly(...args), + useIsChatMode: (...args: unknown[]) => mockUseIsChatMode(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseVarList(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseAvailableVarList(...args), +})) + +vi.mock('@/service/use-strategy', () => ({ + useStrategyProviderDetail: (...args: unknown[]) => mockUseStrategyProviderDetail(...args), +})) + +vi.mock('@/service/use-plugins', () => ({ + useFetchPluginsInMarketPlaceByIds: (...args: unknown[]) => mockUseFetchPluginsInMarketPlaceByIds(...args), + useCheckInstalled: (...args: unknown[]) => mockUseCheckInstalled(...args), +})) + +vi.mock('@/app/components/tools/utils/to-form-schema', () => ({ + generateAgentToolValue: (...args: unknown[]) => mockGenerateAgentToolValue(...args), + toolParametersToFormSchemas: (...args: unknown[]) => mockToolParametersToFormSchemas(...args), +})) + +const createStrategyParam = (overrides: Partial = {}): StrategyParamItem => ({ + name: 'instruction', + type: FormTypeEnum.any, + required: true, + label: { en_US: 'Instruction' } as StrategyParamItem['label'], + help: { en_US: 'Instruction help' } as StrategyParamItem['help'], + placeholder: { en_US: 'Instruction placeholder' } as StrategyParamItem['placeholder'], + scope: 'global', + default: null, + options: [], + template: { enabled: false }, + auto_generate: { type: 'none' }, + ...overrides, +}) + +const createToolValue = () => ({ + settings: { + api_key: 'secret', + }, + parameters: { + query: 'weather', + }, + schemas: [ + { + variable: 'api_key', + form: 'form', + }, + { + variable: 'query', + form: 'llm', + }, + ], +}) + +const createData = (overrides: Partial = {}): AgentNodeType => ({ + title: 'Agent', + desc: '', + type: BlockEnum.Agent, + output_schema: { + properties: { + summary: { + type: 'string', + description: 'summary output', + }, + items: { + type: 'array', + items: { + type: 'number', + }, + description: 'items output', + }, + }, + }, + agent_strategy_provider_name: 'provider/agent', + agent_strategy_name: 'react', + agent_strategy_label: 'React Agent', + plugin_unique_identifier: 'provider/agent:1.0.0', + agent_parameters: { + instruction: { + type: VarType.variable, + value: '#start.topic#', + }, + modelParam: { + type: VarType.constant, + value: { + provider: 'openai', + model: 'gpt-4o', + }, + }, + }, + meta: { version: '1.0.0' } as AgentNodeType['meta'], + ...overrides, +}) + +describe('agent/use-config', () => { + const providerRefetch = vi.fn() + const marketplaceRefetch = vi.fn() + const setInputs = vi.fn() + const handleVarListChange = vi.fn() + const handleAddVariable = vi.fn() + let currentInputs: AgentNodeType + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createData({ + tool_node_version: '2', + }) + + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseIsChatMode.mockReturnValue(true) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs, + })) + mockUseVarList.mockReturnValue({ + handleVarListChange, + handleAddVariable, + } as never) + mockUseAvailableVarList.mockReturnValue({ + availableVars: [{ + nodeId: 'node-1', + title: 'Start', + vars: [{ + variable: 'topic', + type: WorkflowVarType.string, + }], + }], + availableNodesWithParent: [{ + nodeId: 'node-1', + title: 'Start', + }], + } as never) + mockUseStrategyProviderDetail.mockReturnValue({ + isLoading: false, + isError: false, + data: { + declaration: { + strategies: [{ + identity: { + name: 'react', + }, + parameters: [ + createStrategyParam(), + createStrategyParam({ + name: 'modelParam', + type: FormTypeEnum.modelSelector, + required: false, + }), + ], + }], + }, + }, + refetch: providerRefetch, + } as never) + mockUseFetchPluginsInMarketPlaceByIds.mockReturnValue({ + isLoading: false, + data: { + data: { + plugins: [{ id: 'provider/agent' }], + }, + }, + refetch: marketplaceRefetch, + } as never) + mockUseCheckInstalled.mockReturnValue({ + data: { + plugins: [{ + declaration: { + label: { en_US: 'Installed Agent Plugin' }, + }, + }], + }, + } as never) + mockToolParametersToFormSchemas.mockImplementation(value => value as never) + mockGenerateAgentToolValue.mockImplementation((_value, schemas, isLLM) => ({ + kind: isLLM ? 'llm' : 'setting', + fields: (schemas as Array<{ variable: string }>).map(item => item.variable), + }) as never) + }) + + it('returns an undefined strategy status while strategy data is still loading and can refetch dependencies', () => { + mockUseStrategyProviderDetail.mockReturnValue({ + isLoading: true, + isError: false, + data: undefined, + refetch: providerRefetch, + } as never) + + const { result } = renderHook(() => useStrategyInfo('provider/agent', 'react')) + + expect(result.current.strategyStatus).toBeUndefined() + expect(result.current.strategy).toBeUndefined() + + act(() => { + result.current.refetch() + }) + + expect(providerRefetch).toHaveBeenCalledTimes(1) + expect(marketplaceRefetch).toHaveBeenCalledTimes(1) + }) + + it('resolves strategy status for external plugins that are missing or not installed', () => { + mockUseStrategyProviderDetail.mockReturnValue({ + isLoading: false, + isError: true, + data: { + declaration: { + strategies: [], + }, + }, + refetch: providerRefetch, + } as never) + mockUseFetchPluginsInMarketPlaceByIds.mockReturnValue({ + isLoading: false, + data: { + data: { + plugins: [], + }, + }, + refetch: marketplaceRefetch, + } as never) + + const { result } = renderHook(() => useStrategyInfo('provider/agent', 'react')) + + expect(result.current.strategyStatus).toEqual({ + plugin: { + source: 'external', + installed: false, + }, + isExistInPlugin: false, + }) + }) + + it('exposes derived form data, strategy state, output schema, and setter helpers', () => { + const { result } = renderHook(() => useConfig('agent-node', currentInputs)) + + expect(result.current.readOnly).toBe(false) + expect(result.current.isChatMode).toBe(true) + expect(result.current.formData).toEqual({ + instruction: '#start.topic#', + modelParam: { + provider: 'openai', + model: 'gpt-4o', + }, + }) + expect(result.current.currentStrategyStatus).toEqual({ + plugin: { + source: 'marketplace', + installed: true, + }, + isExistInPlugin: true, + }) + expect(result.current.availableVars).toHaveLength(1) + expect(result.current.availableNodesWithParent).toEqual([{ + nodeId: 'node-1', + title: 'Start', + }]) + expect(result.current.outputSchema).toEqual([ + { name: 'summary', type: 'String', description: 'summary output' }, + { name: 'items', type: 'Array[Number]', description: 'items output' }, + ]) + + setInputs.mockClear() + + act(() => { + result.current.onFormChange({ + instruction: '#start.updated#', + modelParam: { + provider: 'anthropic', + model: 'claude-sonnet', + }, + }) + result.current.handleMemoryChange({ + window: { + enabled: true, + size: 6, + }, + query_prompt_template: 'history', + } as AgentNodeType['memory']) + }) + + expect(setInputs).toHaveBeenNthCalledWith(1, expect.objectContaining({ + agent_parameters: { + instruction: { + type: VarType.variable, + value: '#start.updated#', + }, + modelParam: { + type: VarType.constant, + value: { + provider: 'anthropic', + model: 'claude-sonnet', + }, + }, + }, + })) + expect(setInputs).toHaveBeenNthCalledWith(2, expect.objectContaining({ + memory: { + window: { + enabled: true, + size: 6, + }, + query_prompt_template: 'history', + }, + })) + expect(result.current.handleVarListChange).toBe(handleVarListChange) + expect(result.current.handleAddVariable).toBe(handleAddVariable) + expect(result.current.pluginDetail).toEqual({ + declaration: { + label: { en_US: 'Installed Agent Plugin' }, + }, + }) + }) + + it('formats legacy tool selector values before exposing the node config', async () => { + currentInputs = createData({ + tool_node_version: undefined, + agent_parameters: { + toolParam: { + type: VarType.constant, + value: createToolValue(), + }, + multiToolParam: { + type: VarType.constant, + value: [createToolValue()], + }, + }, + }) + mockUseStrategyProviderDetail.mockReturnValue({ + isLoading: false, + isError: false, + data: { + declaration: { + strategies: [{ + identity: { + name: 'react', + }, + parameters: [ + createStrategyParam({ + name: 'toolParam', + type: FormTypeEnum.toolSelector, + required: false, + }), + createStrategyParam({ + name: 'multiToolParam', + type: FormTypeEnum.multiToolSelector, + required: false, + }), + ], + }], + }, + }, + refetch: providerRefetch, + } as never) + + renderHook(() => useConfig('agent-node', currentInputs)) + + await waitFor(() => { + expect(setInputs).toHaveBeenCalledWith(expect.objectContaining({ + tool_node_version: '2', + agent_parameters: expect.objectContaining({ + toolParam: expect.objectContaining({ + value: expect.objectContaining({ + settings: { + kind: 'setting', + fields: ['api_key'], + }, + parameters: { + kind: 'llm', + fields: ['query'], + }, + }), + }), + multiToolParam: expect.objectContaining({ + value: [expect.objectContaining({ + settings: { + kind: 'setting', + fields: ['api_key'], + }, + parameters: { + kind: 'llm', + fields: ['query'], + }, + })], + }), + }), + })) + }) + }) +}) diff --git a/web/app/components/workflow/nodes/agent/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/agent/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..33075e685f --- /dev/null +++ b/web/app/components/workflow/nodes/agent/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,144 @@ +import type { AgentNodeType } from '../types' +import type { InputVar } from '@/app/components/workflow/types' +import { renderHook } from '@testing-library/react' +import formatTracing from '@/app/components/workflow/run/utils/format-log' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import useNodeCrud from '../../_base/hooks/use-node-crud' +import { VarType } from '../../tool/types' +import { useStrategyInfo } from '../use-config' +import useSingleRunFormParams from '../use-single-run-form-params' + +vi.mock('@/app/components/workflow/run/utils/format-log', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('../../_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('../use-config', async () => { + const actual = await vi.importActual('../use-config') + return { + ...actual, + useStrategyInfo: vi.fn(), + } +}) + +const mockFormatTracing = vi.mocked(formatTracing) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseStrategyInfo = vi.mocked(useStrategyInfo) + +const createData = (overrides: Partial = {}): AgentNodeType => ({ + title: 'Agent', + desc: '', + type: BlockEnum.Agent, + output_schema: {}, + agent_strategy_provider_name: 'provider/agent', + agent_strategy_name: 'react', + agent_strategy_label: 'React Agent', + agent_parameters: { + prompt: { + type: VarType.variable, + value: '#start.topic#', + }, + summary: { + type: VarType.variable, + value: '#node-2.answer#', + }, + count: { + type: VarType.constant, + value: 2, + }, + }, + ...overrides, +}) + +describe('agent/use-single-run-form-params', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodeCrud.mockReturnValue({ + inputs: createData(), + setInputs: vi.fn(), + } as unknown as ReturnType) + mockUseStrategyInfo.mockReturnValue({ + strategyProvider: undefined, + strategy: { + parameters: [ + { name: 'prompt', type: 'string' }, + { name: 'summary', type: 'string' }, + { name: 'count', type: 'number' }, + ], + }, + strategyStatus: undefined, + refetch: vi.fn(), + } as unknown as ReturnType) + mockFormatTracing.mockReturnValue([{ + id: 'agent-node', + status: 'succeeded', + }] as unknown as ReturnType) + }) + + it('builds a single-run variable form, returns node info, and skips malformed dependent vars', () => { + const setRunInputData = vi.fn() + const getInputVars = vi.fn<() => InputVar[]>(() => [ + { + label: 'Prompt', + variable: '#start.topic#', + type: InputVarType.textInput, + required: true, + }, + { + label: 'Broken', + variable: undefined as unknown as string, + type: InputVarType.textInput, + required: false, + }, + ]) + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'agent-node', + payload: createData(), + runInputData: { topic: 'finance' }, + runInputDataRef: { current: { topic: 'finance' } }, + getInputVars, + setRunInputData, + toVarInputs: () => [], + runResult: { id: 'trace-1' } as never, + })) + + expect(getInputVars).toHaveBeenCalledWith(['#start.topic#', '#node-2.answer#']) + expect(result.current.forms).toHaveLength(1) + expect(result.current.forms[0].inputs).toHaveLength(2) + expect(result.current.forms[0].values).toEqual({ topic: 'finance' }) + expect(result.current.nodeInfo).toEqual({ + id: 'agent-node', + status: 'succeeded', + }) + + result.current.forms[0].onChange({ topic: 'updated' }) + + expect(setRunInputData).toHaveBeenCalledWith({ topic: 'updated' }) + expect(result.current.getDependentVars()).toEqual([ + ['start', 'topic'], + ]) + }) + + it('returns an empty form list when no variable input is required and no run result is available', () => { + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'agent-node', + payload: createData(), + runInputData: {}, + runInputDataRef: { current: {} }, + getInputVars: () => [], + setRunInputData: vi.fn(), + toVarInputs: () => [], + runResult: undefined as never, + })) + + expect(result.current.forms).toEqual([]) + expect(result.current.nodeInfo).toBeUndefined() + expect(result.current.getDependentVars()).toEqual([]) + }) +}) diff --git a/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx b/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx new file mode 100644 index 0000000000..d85f54ed19 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx @@ -0,0 +1,78 @@ +import type { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { fireEvent, render, screen } from '@testing-library/react' +import { ModelBar } from '../model-bar' + +type ModelProviderItem = { + provider: string + models: Array<{ model: string }> +} + +const mockModelLists = new Map() + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelList: (modelType: ModelTypeEnum) => ({ + data: mockModelLists.get(modelType) || [], + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({ + default: ({ + defaultModel, + modelList, + }: { + defaultModel?: { provider: string, model: string } + modelList: ModelProviderItem[] + }) => ( +
+ {defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-model'} + : + {modelList.length} +
+ ), +})) + +vi.mock('@/app/components/header/indicator', () => ({ + default: ({ color }: { color: string }) =>
{`indicator:${color}`}
, +})) + +describe('agent/model-bar', () => { + beforeEach(() => { + vi.clearAllMocks() + mockModelLists.clear() + mockModelLists.set('llm' as ModelTypeEnum, [{ provider: 'openai', models: [{ model: 'gpt-4o' }] }]) + mockModelLists.set('moderation' as ModelTypeEnum, []) + mockModelLists.set('rerank' as ModelTypeEnum, []) + mockModelLists.set('speech2text' as ModelTypeEnum, []) + mockModelLists.set('text-embedding' as ModelTypeEnum, []) + mockModelLists.set('tts' as ModelTypeEnum, []) + }) + + it('should render an empty readonly selector with a warning when no model is selected', () => { + render() + + const emptySelector = screen.getByText((_, element) => element?.textContent === 'no-model:0') + + fireEvent.mouseEnter(emptySelector) + + expect(emptySelector).toBeInTheDocument() + expect(screen.getByText('indicator:red')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.agent.modelNotSelected')).toBeInTheDocument() + }) + + it('should render the selected model without warning when it is installed', () => { + render() + + expect(screen.getByText('openai/gpt-4o:1')).toBeInTheDocument() + expect(screen.queryByText('indicator:red')).not.toBeInTheDocument() + }) + + it('should show a warning tooltip when the selected model is not installed', () => { + render() + + fireEvent.mouseEnter(screen.getByText('openai/gpt-4.1:1')) + + expect(screen.getByText('openai/gpt-4.1:1')).toBeInTheDocument() + expect(screen.getByText('indicator:red')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.agent.modelNotInstallTooltip')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx b/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx new file mode 100644 index 0000000000..30a12bb528 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx @@ -0,0 +1,113 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { ToolIcon } from '../tool-icon' + +type ToolProvider = { + id?: string + name?: string + icon?: string | { content: string, background: string } + is_team_authorization?: boolean +} + +let mockBuiltInTools: ToolProvider[] | undefined +let mockCustomTools: ToolProvider[] | undefined +let mockWorkflowTools: ToolProvider[] | undefined +let mockMcpTools: ToolProvider[] | undefined +let mockMarketplaceIcon: string | { content: string, background: string } | undefined + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: mockBuiltInTools }), + useAllCustomTools: () => ({ data: mockCustomTools }), + useAllWorkflowTools: () => ({ data: mockWorkflowTools }), + useAllMCPTools: () => ({ data: mockMcpTools }), +})) + +vi.mock('@/app/components/base/app-icon', () => ({ + default: ({ + icon, + background, + className, + }: { + icon?: string + background?: string + className?: string + }) =>
{`app-icon:${background}:${icon}`}
, +})) + +vi.mock('@/app/components/base/icons/src/vender/other', () => ({ + Group: ({ className }: { className?: string }) =>
group-icon
, +})) + +vi.mock('@/app/components/header/indicator', () => ({ + default: ({ color }: { color: string }) =>
{`indicator:${color}`}
, +})) + +vi.mock('@/utils/get-icon', () => ({ + getIconFromMarketPlace: () => mockMarketplaceIcon, +})) + +describe('agent/tool-icon', () => { + beforeEach(() => { + vi.clearAllMocks() + mockBuiltInTools = [] + mockCustomTools = [] + mockWorkflowTools = [] + mockMcpTools = [] + mockMarketplaceIcon = undefined + }) + + it('should render a string icon, recover from fetch errors, and keep installed tools warning-free', () => { + mockBuiltInTools = [{ + name: 'author/tool-a', + icon: 'https://example.com/tool-a.png', + is_team_authorization: true, + }] + + render() + + const icon = screen.getByRole('img', { name: 'tool icon' }) + expect(icon).toHaveAttribute('src', 'https://example.com/tool-a.png') + expect(screen.queryByText(/indicator:/)).not.toBeInTheDocument() + + fireEvent.mouseEnter(icon) + expect(screen.queryByText('workflow.nodes.agent.toolNotInstallTooltip')).not.toBeInTheDocument() + + fireEvent.error(icon) + expect(screen.getByText('group-icon')).toBeInTheDocument() + }) + + it('should render authorization and installation warnings with the correct icon sources', () => { + mockWorkflowTools = [{ + id: 'author/tool-b', + icon: { + content: 'B', + background: '#fff', + }, + is_team_authorization: false, + }] + + const { rerender } = render() + + fireEvent.mouseEnter(screen.getByText('app-icon:#fff:B')) + expect(screen.getByText('indicator:yellow')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.agent.toolNotAuthorizedTooltip:{"tool":"tool-b"}')).toBeInTheDocument() + + mockWorkflowTools = [] + mockMarketplaceIcon = 'https://example.com/market-tool.png' + rerender() + + const marketplaceIcon = screen.getByRole('img', { name: 'tool icon' }) + fireEvent.mouseEnter(marketplaceIcon) + expect(marketplaceIcon).toHaveAttribute('src', 'https://example.com/market-tool.png') + expect(screen.getByText('indicator:red')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.agent.toolNotInstallTooltip:{"tool":"tool-c"}')).toBeInTheDocument() + }) + + it('should fall back to the group icon while tool data is still loading', () => { + mockBuiltInTools = undefined + + render() + + expect(screen.getByText('group-icon')).toBeInTheDocument() + expect(screen.queryByText(/indicator:/)).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/answer/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/answer/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..b5fbdf163f --- /dev/null +++ b/web/app/components/workflow/nodes/answer/__tests__/panel.spec.tsx @@ -0,0 +1,92 @@ +import type { AnswerNodeType } from '../types' +import type { PanelProps } from '@/types/workflow' +import { fireEvent, render, screen } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' + +type MockEditorProps = { + readOnly: boolean + title: string + value: string + onChange: (value: string) => void + nodesOutputVars: unknown[] + availableNodes: unknown[] +} + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockUseAvailableVarList = vi.hoisted(() => vi.fn()) +const mockEditorRender = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseAvailableVarList(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/prompt/editor', () => ({ + __esModule: true, + default: (props: MockEditorProps) => { + mockEditorRender(props) + return ( + + ) + }, +})) + +const createData = (overrides: Partial = {}): AnswerNodeType => ({ + title: 'Answer', + desc: '', + type: BlockEnum.Answer, + variables: [], + answer: 'Initial answer', + ...overrides, +}) + +describe('AnswerPanel', () => { + const handleAnswerChange = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue({ + readOnly: false, + inputs: createData(), + handleAnswerChange, + filterVar: vi.fn(), + }) + mockUseAvailableVarList.mockReturnValue({ + availableVars: [{ variable: 'context', type: 'string' }], + availableNodesWithParent: [{ value: 'node-1', label: 'Node 1' }], + }) + }) + + it('should pass editor state and available variables through to the prompt editor', () => { + render() + + expect(screen.getByRole('button', { name: 'workflow.nodes.answer.answer:Initial answer' })).toBeInTheDocument() + expect(mockEditorRender).toHaveBeenCalledWith(expect.objectContaining({ + readOnly: false, + title: 'workflow.nodes.answer.answer', + value: 'Initial answer', + nodesOutputVars: [{ variable: 'context', type: 'string' }], + availableNodes: [{ value: 'node-1', label: 'Node 1' }], + isSupportFileVar: true, + justVar: true, + })) + }) + + it('should delegate answer edits to use-config', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'workflow.nodes.answer.answer:Initial answer' })) + + expect(handleAnswerChange).toHaveBeenCalledWith('Updated answer') + }) +}) diff --git a/web/app/components/workflow/nodes/answer/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/answer/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..106355e8c5 --- /dev/null +++ b/web/app/components/workflow/nodes/answer/__tests__/use-config.spec.ts @@ -0,0 +1,81 @@ +import type { AnswerNodeType } from '../types' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseVarList = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: () => mockUseNodesReadOnly(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseVarList(...args), +})) + +const createPayload = (overrides: Partial = {}): AnswerNodeType => ({ + title: 'Answer', + desc: '', + type: BlockEnum.Answer, + variables: [], + answer: 'Initial answer', + ...overrides, +}) + +describe('answer/use-config', () => { + const mockSetInputs = vi.fn() + const mockHandleVarListChange = vi.fn() + const mockHandleAddVariable = vi.fn() + let currentInputs: AnswerNodeType + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false }) + mockUseNodeCrud.mockReturnValue({ + inputs: currentInputs, + setInputs: mockSetInputs, + }) + mockUseVarList.mockReturnValue({ + handleVarListChange: mockHandleVarListChange, + handleAddVariable: mockHandleAddVariable, + }) + }) + + it('should update the answer text and expose var-list handlers', () => { + const { result } = renderHook(() => useConfig('answer-node', currentInputs)) + + act(() => { + result.current.handleAnswerChange('Updated answer') + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + answer: 'Updated answer', + })) + expect(result.current.handleVarListChange).toBe(mockHandleVarListChange) + expect(result.current.handleAddVariable).toBe(mockHandleAddVariable) + expect(result.current.readOnly).toBe(false) + }) + + it('should filter out array-object variables from the prompt editor picker', () => { + const { result } = renderHook(() => useConfig('answer-node', currentInputs)) + + expect(result.current.filterVar({ + variable: 'items', + type: VarType.arrayObject, + })).toBe(false) + expect(result.current.filterVar({ + variable: 'message', + type: VarType.string, + })).toBe(true) + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/assigner/__tests__/node.spec.tsx new file mode 100644 index 0000000000..a1fd87d386 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/__tests__/node.spec.tsx @@ -0,0 +1,150 @@ +import type { AssignerNodeOperation, AssignerNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import { useNodes } from 'reactflow' +import { BlockEnum } from '@/app/components/workflow/types' +import Node from '../node' +import { AssignerNodeInputType, WriteMode } from '../types' + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useNodes: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/variable-label', () => ({ + VariableLabelInNode: ({ + variables, + nodeTitle, + nodeType, + rightSlot, + }: { + variables: string[] + nodeTitle?: string + nodeType?: BlockEnum + rightSlot?: React.ReactNode + }) => ( +
+ {`${nodeTitle}:${nodeType}:${variables.join('.')}`} + {rightSlot} +
+ ), +})) + +const mockUseNodes = vi.mocked(useNodes) + +const createOperation = (overrides: Partial = {}): AssignerNodeOperation => ({ + variable_selector: ['node-1', 'count'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-2', 'result'], + ...overrides, +}) + +const createData = (overrides: Partial = {}): AssignerNodeType => ({ + title: 'Assigner', + desc: '', + type: BlockEnum.VariableAssigner, + version: '2', + items: [createOperation()], + ...overrides, +}) + +describe('assigner/node', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodes.mockReturnValue([ + { + id: 'node-1', + data: { + title: 'Answer', + type: BlockEnum.Answer, + }, + }, + { + id: 'start-node', + data: { + title: 'Start', + type: BlockEnum.Start, + }, + }, + ] as ReturnType) + }) + + it('renders the empty-state hint when no assignable variable is configured', () => { + render( + , + ) + + expect(screen.getByText('workflow.nodes.assigner.varNotSet')).toBeInTheDocument() + }) + + it('renders both version 2 and legacy previews with resolved node labels', () => { + const { container, rerender } = render( + , + ) + + expect(screen.getByText('Answer:answer:node-1.count')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.over-write')).toBeInTheDocument() + + rerender( + , + ) + + expect(screen.getByText('Start:start:sys.query')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.append')).toBeInTheDocument() + + rerender( + , + ) + + expect(container).toBeEmptyDOMElement() + }) + + it('skips empty v2 operations and resolves system variables through the start node', () => { + render( + , + ) + + expect(screen.getByText('Start:start:sys.query')).toBeInTheDocument() + expect(screen.queryByText('undefined:undefined:')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/assigner/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..c70c84beab --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/__tests__/panel.spec.tsx @@ -0,0 +1,119 @@ +import type { AssignerNodeOperation, AssignerNodeType } from '../types' +import type { PanelProps } from '@/types/workflow' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' +import { AssignerNodeInputType, WriteMode } from '../types' + +type MockVarListProps = { + readonly: boolean + nodeId: string + list: AssignerNodeOperation[] + onChange: (list: AssignerNodeOperation[]) => void +} + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockUseHandleAddOperationItem = vi.hoisted(() => vi.fn()) +const mockVarListRender = vi.hoisted(() => vi.fn()) + +const createOperation = (overrides: Partial = {}): AssignerNodeOperation => ({ + variable_selector: ['node-1', 'count'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-2', 'result'], + ...overrides, +}) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../hooks', () => ({ + useHandleAddOperationItem: () => mockUseHandleAddOperationItem, +})) + +vi.mock('../components/var-list', () => ({ + __esModule: true, + default: (props: MockVarListProps) => { + mockVarListRender(props) + return ( +
+
{props.list.map(item => item.variable_selector.join('.')).join(',')}
+ +
+ ) + }, +})) + +const createData = (overrides: Partial = {}): AssignerNodeType => ({ + title: 'Assigner', + desc: '', + type: BlockEnum.VariableAssigner, + version: '2', + items: [createOperation()], + ...overrides, +}) + +const panelProps = {} as PanelProps + +describe('assigner/panel', () => { + const handleOperationListChanges = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseHandleAddOperationItem.mockReturnValue([ + createOperation(), + createOperation({ variable_selector: [] }), + ]) + mockUseConfig.mockReturnValue({ + readOnly: false, + inputs: createData(), + handleOperationListChanges, + getAssignedVarType: vi.fn(), + getToAssignedVarType: vi.fn(), + writeModeTypesNum: [], + writeModeTypesArr: [], + writeModeTypes: [], + filterAssignedVar: vi.fn(), + filterToAssignedVar: vi.fn(), + }) + }) + + it('passes the resolved config to the variable list and appends operations through the add button', async () => { + const user = userEvent.setup() + + render( + , + ) + + expect(screen.getByText('workflow.nodes.assigner.variables')).toBeInTheDocument() + expect(screen.getByText('node-1.count')).toBeInTheDocument() + expect(mockVarListRender).toHaveBeenCalledWith(expect.objectContaining({ + readonly: false, + nodeId: 'assigner-node', + list: createData().items, + })) + + await user.click(screen.getAllByRole('button')[0]!) + + expect(mockUseHandleAddOperationItem).toHaveBeenCalledWith(createData().items) + expect(handleOperationListChanges).toHaveBeenCalledWith([ + createOperation(), + createOperation({ variable_selector: [] }), + ]) + + await user.click(screen.getByRole('button', { name: 'emit-list-change' })) + + expect(handleOperationListChanges).toHaveBeenCalledWith([ + createOperation({ variable_selector: ['node-1', 'updated'] }), + ]) + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/assigner/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..0551d1fd30 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,85 @@ +import type { AssignerNodeOperation, AssignerNodeType } from '../types' +import type { InputVar } from '@/app/components/workflow/types' +import { renderHook } from '@testing-library/react' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import useNodeCrud from '../../_base/hooks/use-node-crud' +import { AssignerNodeInputType, WriteMode } from '../types' +import useSingleRunFormParams from '../use-single-run-form-params' + +vi.mock('../../_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseNodeCrud = vi.mocked(useNodeCrud) + +const createOperation = (overrides: Partial = {}): AssignerNodeOperation => ({ + variable_selector: ['node-1', 'target'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-2', 'result'], + ...overrides, +}) + +const createData = (overrides: Partial = {}): AssignerNodeType => ({ + title: 'Assigner', + desc: '', + type: BlockEnum.VariableAssigner, + version: '2', + items: [ + createOperation(), + createOperation({ operation: WriteMode.append, value: ['node-3', 'items'] }), + createOperation({ operation: WriteMode.clear, value: ['node-4', 'unused'] }), + createOperation({ operation: WriteMode.set, input_type: AssignerNodeInputType.constant, value: 'fixed' }), + createOperation({ operation: WriteMode.increment, input_type: AssignerNodeInputType.constant, value: 2 }), + ], + ...overrides, +}) + +describe('assigner/use-single-run-form-params', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodeCrud.mockReturnValue({ + inputs: createData(), + setInputs: vi.fn(), + } as unknown as ReturnType) + }) + + it('exposes only variable-driven dependencies in the single-run form', () => { + const setRunInputData = vi.fn() + const varInputs: InputVar[] = [{ + label: 'Result', + variable: 'result', + type: InputVarType.textInput, + required: true, + }] + const varSelectorsToVarInputs = vi.fn(() => varInputs) + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'assigner-node', + payload: createData(), + runInputData: { result: 'hello' }, + runInputDataRef: { current: {} }, + getInputVars: () => [], + setRunInputData, + toVarInputs: () => [], + varSelectorsToVarInputs, + })) + + expect(varSelectorsToVarInputs).toHaveBeenCalledWith([ + ['node-2', 'result'], + ['node-3', 'items'], + ]) + expect(result.current.forms).toHaveLength(1) + expect(result.current.forms[0].inputs).toEqual(varInputs) + expect(result.current.forms[0].values).toEqual({ result: 'hello' }) + + result.current.forms[0].onChange({ result: 'updated' }) + + expect(setRunInputData).toHaveBeenCalledWith({ result: 'updated' }) + expect(result.current.getDependentVars()).toEqual([ + ['node-2', 'result'], + ['node-3', 'items'], + ]) + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx b/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx new file mode 100644 index 0000000000..63813c8a46 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx @@ -0,0 +1,52 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { VarType } from '@/app/components/workflow/types' +import { WriteMode } from '../../types' +import OperationSelector from '../operation-selector' + +describe('assigner/operation-selector', () => { + it('shows numeric write modes and emits the selected operation', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + render( + , + ) + + await user.click(screen.getByText('workflow.nodes.assigner.operations.over-write')) + + expect(screen.getByText('workflow.nodes.assigner.operations.title')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.clear')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.set')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.+=')).toBeInTheDocument() + + await user.click(screen.getAllByText('workflow.nodes.assigner.operations.+=').at(-1)!) + + expect(onSelect).toHaveBeenCalledWith({ value: WriteMode.increment, name: WriteMode.increment }) + }) + + it('does not open when the selector is disabled', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('workflow.nodes.assigner.operations.over-write')) + + expect(screen.queryByText('workflow.nodes.assigner.operations.title')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/branches.spec.tsx b/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/branches.spec.tsx new file mode 100644 index 0000000000..a9b5a304f4 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/branches.spec.tsx @@ -0,0 +1,213 @@ +import type { ComponentProps } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { VarType } from '@/app/components/workflow/types' +import { AssignerNodeInputType, WriteMode } from '../../../types' +import VarList from '../index' + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({ + __esModule: true, + default: ({ + popupFor = 'assigned', + onOpen, + onChange, + }: { + popupFor?: string + onOpen?: () => void + onChange: (value: string[]) => void + }) => ( +
+ + +
+ ), +})) + +vi.mock('../../operation-selector', () => ({ + __esModule: true, + default: ({ + onSelect, + }: { + onSelect: (item: { value: string }) => void + }) => ( +
+ + +
+ ), +})) + +const createOperation = ( + overrides: Partial['list'][number]> = {}, +): ComponentProps['list'][number] => ({ + variable_selector: ['node-a', 'flag'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-a', 'answer'], + ...overrides, +}) + +const renderVarList = (props: Partial> = {}) => { + const handleChange = vi.fn() + const handleOpen = vi.fn() + + const result = render( + VarType.string} + getToAssignedVarType={() => VarType.string} + writeModeTypes={[WriteMode.overwrite, WriteMode.clear, WriteMode.set]} + writeModeTypesArr={[WriteMode.overwrite, WriteMode.clear]} + writeModeTypesNum={[WriteMode.increment]} + {...props} + />, + ) + + return { + ...result, + handleChange, + handleOpen, + } +} + +describe('assigner/var-list branches', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('resets operation metadata when the assigned variable changes', async () => { + const user = userEvent.setup() + const { handleChange, handleOpen } = renderVarList({ + list: [createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 'stale', + })], + }) + + await user.click(screen.getByTestId('assigned-picker-trigger')) + await user.click(screen.getByRole('button', { name: 'select-assigned' })) + + expect(handleOpen).toHaveBeenCalledWith(0) + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + variable_selector: ['node-b', 'total'], + operation: WriteMode.overwrite, + input_type: AssignerNodeInputType.variable, + value: undefined, + }), + ], ['node-b', 'total']) + }) + + it('switches back to variable mode when the selected operation no longer requires a constant', async () => { + const user = userEvent.setup() + const { handleChange } = renderVarList({ + list: [createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 'hello', + })], + }) + + await user.click(screen.getByRole('button', { name: 'operation-overwrite' })) + + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + operation: WriteMode.overwrite, + input_type: AssignerNodeInputType.variable, + value: '', + }), + ]) + }) + + it('updates string and number constant inputs through the inline editors', () => { + const { handleChange, rerender } = renderVarList({ + list: [createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 1, + })], + getAssignedVarType: () => VarType.number, + getToAssignedVarType: () => VarType.number, + }) + + fireEvent.change(screen.getByRole('spinbutton'), { + target: { value: '2' }, + }) + + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 2, + }), + ], 2) + + rerender( + VarType.string} + getToAssignedVarType={() => VarType.string} + writeModeTypes={[WriteMode.overwrite, WriteMode.clear, WriteMode.set]} + writeModeTypesArr={[WriteMode.overwrite, WriteMode.clear]} + writeModeTypesNum={[WriteMode.increment]} + />, + ) + + fireEvent.change(screen.getByRole('textbox'), { + target: { value: 'updated' }, + }) + + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 'updated', + }), + ], 'updated') + }) + + it('updates numeric write-mode inputs through the dedicated number field', () => { + const { handleChange } = renderVarList({ + list: [createOperation({ + operation: WriteMode.increment, + value: 2, + })], + getAssignedVarType: () => VarType.number, + getToAssignedVarType: () => VarType.number, + writeModeTypesNum: [WriteMode.increment], + }) + + fireEvent.change(screen.getByRole('spinbutton'), { + target: { value: '5' }, + }) + + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + operation: WriteMode.increment, + value: 5, + }), + ], 5) + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/index.spec.tsx b/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/index.spec.tsx new file mode 100644 index 0000000000..f7408ab814 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/index.spec.tsx @@ -0,0 +1,146 @@ +import type { ComponentProps } from 'react' +import { fireEvent, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { createNode, resetFixtureCounters } from '@/app/components/workflow/__tests__/fixtures' +import { renderWorkflowFlowComponent } from '@/app/components/workflow/__tests__/workflow-test-env' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import { AssignerNodeInputType, WriteMode } from '../../../types' +import VarList from '../index' + +const sourceNode = createNode({ + id: 'node-a', + data: { + type: BlockEnum.Answer, + title: 'Answer Node', + outputs: { + answer: { type: VarType.string }, + flag: { type: VarType.boolean }, + }, + }, +}) + +const currentNode = createNode({ + id: 'node-current', + data: { + type: BlockEnum.VariableAssigner, + title: 'Assigner Node', + }, +}) + +const createOperation = (overrides: Partial['list'][number]> = {}) => ({ + variable_selector: ['node-a', 'flag'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-a', 'answer'], + ...overrides, +}) + +const renderVarList = (props: Partial> = {}) => { + const handleChange = vi.fn() + const handleOpen = vi.fn() + + const result = renderWorkflowFlowComponent( + VarType.string} + getToAssignedVarType={() => VarType.string} + writeModeTypes={[WriteMode.overwrite, WriteMode.clear, WriteMode.set]} + writeModeTypesArr={[WriteMode.overwrite, WriteMode.clear]} + writeModeTypesNum={[WriteMode.increment]} + {...props} + />, + { + nodes: [sourceNode, currentNode], + edges: [], + hooksStoreProps: {}, + }, + ) + + return { + ...result, + handleChange, + handleOpen, + } +} + +describe('assigner/var-list', () => { + beforeEach(() => { + resetFixtureCounters() + }) + + it('renders the empty placeholder when no operations are configured', () => { + renderVarList() + + expect(screen.getByText('workflow.nodes.assigner.noVarTip')).toBeInTheDocument() + }) + + it('switches a boolean assignment to constant mode and updates the selected value', async () => { + const user = userEvent.setup() + const list = [createOperation()] + const { handleChange, rerender } = renderVarList({ + list, + getAssignedVarType: () => VarType.boolean, + getToAssignedVarType: () => VarType.boolean, + }) + + await user.click(screen.getByText('workflow.nodes.assigner.operations.over-write')) + await user.click(screen.getAllByText('workflow.nodes.assigner.operations.set').at(-1)!) + + expect(handleChange.mock.lastCall?.[0]).toEqual([ + createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: false, + }), + ]) + + rerender( + VarType.boolean} + getToAssignedVarType={() => VarType.boolean} + writeModeTypes={[WriteMode.overwrite, WriteMode.clear, WriteMode.set]} + writeModeTypesArr={[WriteMode.overwrite, WriteMode.clear]} + writeModeTypesNum={[WriteMode.increment]} + />, + ) + + await user.click(screen.getByText('True')) + + expect(handleChange.mock.lastCall?.[0]).toEqual([ + createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: true, + }), + ]) + }) + + it('opens the assigned-variable picker and removes an operation', () => { + const { handleChange, handleOpen } = renderVarList({ + list: [createOperation()], + }) + + fireEvent.click(screen.getAllByTestId('var-reference-picker-trigger')[0]!) + expect(handleOpen).toHaveBeenCalledWith(0) + + const buttons = screen.getAllByRole('button') + fireEvent.click(buttons[buttons.length - 1]!) + + expect(handleChange).toHaveBeenLastCalledWith([]) + }) +}) diff --git a/web/app/components/workflow/nodes/code/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/code/__tests__/node.spec.tsx new file mode 100644 index 0000000000..a8648324ed --- /dev/null +++ b/web/app/components/workflow/nodes/code/__tests__/node.spec.tsx @@ -0,0 +1,29 @@ +import type { CodeNodeType } from '../types' +import { render } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import Node from '../node' +import { CodeLanguage } from '../types' + +const createData = (overrides: Partial = {}): CodeNodeType => ({ + title: 'Code', + desc: '', + type: BlockEnum.Code, + variables: [], + code_language: CodeLanguage.javascript, + code: 'function main() { return {} }', + outputs: {}, + ...overrides, +}) + +describe('code/node', () => { + it('renders an empty summary container', () => { + const { container } = render( + , + ) + + expect(container.firstChild).toBeEmptyDOMElement() + }) +}) diff --git a/web/app/components/workflow/nodes/code/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/code/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..72d640651d --- /dev/null +++ b/web/app/components/workflow/nodes/code/__tests__/panel.spec.tsx @@ -0,0 +1,295 @@ +import type { ReactNode } from 'react' +import type { CodeNodeType, OutputVar } from '../types' +import type useConfig from '../use-config' +import type { NodePanelProps, Variable } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import Panel from '../panel' +import { CodeLanguage } from '../types' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockExtractFunctionParams = vi.hoisted(() => vi.fn()) +const mockExtractReturnType = vi.hoisted(() => vi.fn()) +const mockCodeEditor = vi.hoisted(() => vi.fn()) +const mockVarList = vi.hoisted(() => vi.fn()) +const mockOutputVarList = vi.hoisted(() => vi.fn()) +const mockRemoveEffectVarConfirm = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../code-parser', () => ({ + extractFunctionParams: (...args: unknown[]) => mockExtractFunctionParams(...args), + extractReturnType: (...args: unknown[]) => mockExtractReturnType(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({ + __esModule: true, + default: (props: { + readOnly: boolean + language: CodeLanguage + value: string + onChange: (value: string) => void + onGenerated: (value: string) => void + title: ReactNode + }) => { + mockCodeEditor(props) + return ( +
+
{props.readOnly ? 'editor:readonly' : 'editor:editable'}
+
{props.language}
+
{props.title}
+ + +
+ ) + }, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/selector', () => ({ + __esModule: true, + default: (props: { + value: CodeLanguage + onChange: (value: CodeLanguage) => void + }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-list', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + list: Variable[] + onChange: (list: Variable[]) => void + }) => { + mockVarList(props) + return ( +
+
{props.readonly ? 'var-list:readonly' : 'var-list:editable'}
+ +
+ ) + }, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/output-var-list', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + outputs: OutputVar + onChange: (outputs: OutputVar) => void + onRemove: (name: string) => void + }) => { + mockOutputVarList(props) + return ( +
+
{props.readonly ? 'output-list:readonly' : 'output-list:editable'}
+ + +
+ ) + }, +})) + +vi.mock('../../_base/components/remove-effect-var-confirm', () => ({ + __esModule: true, + default: (props: { + isShow: boolean + onCancel: () => void + onConfirm: () => void + }) => { + mockRemoveEffectVarConfirm(props) + return props.isShow + ? ( +
+ + +
+ ) + : null + }, +})) + +const createData = (overrides: Partial = {}): CodeNodeType => ({ + title: 'Code', + desc: '', + type: BlockEnum.Code, + code_language: CodeLanguage.javascript, + code: 'function main({ foo }) { return { result: foo } }', + variables: [{ + variable: 'foo', + value_selector: ['start', 'foo'], + value_type: VarType.string, + }], + outputs: { + result: { + type: VarType.string, + children: null, + }, + }, + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + outputKeyOrders: ['result'], + handleCodeAndVarsChange: vi.fn(), + handleVarListChange: vi.fn(), + handleAddVariable: vi.fn(), + handleRemoveVariable: vi.fn(), + handleSyncFunctionSignature: vi.fn(), + handleCodeChange: vi.fn(), + handleCodeLanguageChange: vi.fn(), + handleVarsChange: vi.fn(), + handleAddOutputVariable: vi.fn(), + filterVar: vi.fn(() => true), + isShowRemoveVarConfirm: true, + hideRemoveVarConfirm: vi.fn(), + onRemoveVarConfirm: vi.fn(), + ...overrides, +}) + +const renderPanel = (data: CodeNodeType = createData()) => { + const props: NodePanelProps = { + id: 'code-node', + data, + panelProps: { + getInputVars: vi.fn(() => []), + toVarInputs: vi.fn(() => []), + runInputData: {}, + runInputDataRef: { current: {} }, + setRunInputData: vi.fn(), + runResult: null, + }, + } + + return render() +} + +describe('code/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockExtractFunctionParams.mockReturnValue(['summary', 'count']) + mockExtractReturnType.mockReturnValue({ + result: { + type: VarType.string, + children: null, + }, + }) + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders editable controls and forwards all input, output, and code actions', async () => { + const user = userEvent.setup() + const config = createConfigResult() + mockUseConfig.mockReturnValue(config) + + renderPanel() + + expect(screen.getByText('workflow.nodes.code.inputVars')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.code.outputVars')).toBeInTheDocument() + expect(screen.getByText('editor:editable')).toBeInTheDocument() + expect(screen.getByText('language:javascript')).toBeInTheDocument() + + const addButtons = screen.getAllByTestId('add-button') + await user.click(addButtons[0]!) + await user.click(screen.getByTestId('sync-button')) + await user.click(screen.getByRole('button', { name: 'change-code' })) + await user.click(screen.getByRole('button', { name: 'generate-code' })) + await user.click(screen.getByRole('button', { name: 'language:javascript' })) + await user.click(screen.getByRole('button', { name: 'change-var-list' })) + await user.click(screen.getByRole('button', { name: 'change-output-list' })) + await user.click(screen.getByRole('button', { name: 'remove-output' })) + await user.click(addButtons[1]!) + await user.click(screen.getByRole('button', { name: 'cancel-remove' })) + await user.click(screen.getByRole('button', { name: 'confirm-remove' })) + + expect(config.handleAddVariable).toHaveBeenCalled() + expect(config.handleSyncFunctionSignature).toHaveBeenCalled() + expect(config.handleCodeChange).toHaveBeenCalledWith('generated code body') + expect(config.handleCodeLanguageChange).toHaveBeenCalledWith(CodeLanguage.python3) + expect(config.handleVarListChange).toHaveBeenCalledWith([{ + variable: 'changed', + value_selector: ['start', 'changed'], + }]) + expect(config.handleVarsChange).toHaveBeenCalledWith({ + next_result: { + type: VarType.number, + children: null, + }, + }) + expect(config.handleRemoveVariable).toHaveBeenCalledWith('result') + expect(config.handleAddOutputVariable).toHaveBeenCalled() + expect(config.hideRemoveVarConfirm).toHaveBeenCalled() + expect(config.onRemoveVarConfirm).toHaveBeenCalled() + expect(config.handleCodeAndVarsChange).toHaveBeenCalledWith( + 'generated signature code', + [{ + variable: 'summary', + value_selector: [], + }, { + variable: 'count', + value_selector: [], + }], + { + result: { + type: VarType.string, + children: null, + }, + }, + ) + expect(mockExtractFunctionParams).toHaveBeenCalledWith('generated signature code', CodeLanguage.javascript) + expect(mockExtractReturnType).toHaveBeenCalledWith('generated signature code', CodeLanguage.javascript) + }) + + it('removes input actions in readonly mode and passes readonly state to child sections', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + readOnly: true, + isShowRemoveVarConfirm: false, + })) + + renderPanel() + + expect(screen.queryByTestId('sync-button')).not.toBeInTheDocument() + expect(screen.getAllByTestId('add-button')).toHaveLength(1) + expect(screen.getByText('editor:readonly')).toBeInTheDocument() + expect(screen.getByText('var-list:readonly')).toBeInTheDocument() + expect(screen.getByText('output-list:readonly')).toBeInTheDocument() + expect(mockRemoveEffectVarConfirm).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/nodes/code/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/code/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..b02ff8a4fc --- /dev/null +++ b/web/app/components/workflow/nodes/code/__tests__/use-config.spec.ts @@ -0,0 +1,315 @@ +import type { CodeNodeType, OutputVar } from '../types' +import type { Var, Variable } from '@/app/components/workflow/types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { useNodesReadOnly } from '@/app/components/workflow/hooks' +import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' +import { useStore } from '@/app/components/workflow/store' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import { fetchNodeDefault, fetchPipelineNodeDefault } from '@/service/workflow' +import useOutputVarList from '../../_base/hooks/use-output-var-list' +import useVarList from '../../_base/hooks/use-var-list' +import { CodeLanguage } from '../types' +import useConfig from '../use-config' + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-output-var-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: vi.fn(), +})) + +vi.mock('@/service/workflow', () => ({ + fetchNodeDefault: vi.fn(), + fetchPipelineNodeDefault: vi.fn(), +})) + +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseVarList = vi.mocked(useVarList) +const mockUseOutputVarList = vi.mocked(useOutputVarList) +const mockUseStore = vi.mocked(useStore) +const mockFetchNodeDefault = vi.mocked(fetchNodeDefault) +const mockFetchPipelineNodeDefault = vi.mocked(fetchPipelineNodeDefault) + +const createVariable = (variable: string, valueType: VarType = VarType.string): Variable => ({ + variable, + value_selector: ['start', variable], + value_type: valueType, +}) + +const createOutputs = (name = 'result', type: VarType = VarType.string): OutputVar => ({ + [name]: { + type, + children: null, + }, +}) + +const createData = (overrides: Partial = {}): CodeNodeType => ({ + title: 'Code', + desc: '', + type: BlockEnum.Code, + code_language: CodeLanguage.javascript, + code: 'function main({ foo }) { return { result: foo } }', + variables: [createVariable('foo')], + outputs: createOutputs(), + ...overrides, +}) + +describe('code/use-config', () => { + const mockSetInputs = vi.fn() + const mockHandleVarListChange = vi.fn() + const mockHandleAddVariable = vi.fn() + const mockHandleVarsChange = vi.fn() + const mockHandleAddOutputVariable = vi.fn() + const mockHandleRemoveVariable = vi.fn() + const mockHideRemoveVarConfirm = vi.fn() + const mockOnRemoveVarConfirm = vi.fn() + + let workflowStoreState: { + appId?: string + pipelineId?: string + nodesDefaultConfigs?: Record + } + let currentInputs: CodeNodeType + let javaScriptConfig: CodeNodeType + let pythonConfig: CodeNodeType + + beforeEach(() => { + vi.clearAllMocks() + + javaScriptConfig = createData({ + code_language: CodeLanguage.javascript, + code: 'function main({ query }) { return { result: query } }', + variables: [createVariable('query')], + outputs: createOutputs('result'), + }) + pythonConfig = createData({ + code_language: CodeLanguage.python3, + code: 'def main(name: str):\n return {"result": name}', + variables: [createVariable('name')], + outputs: createOutputs('result'), + }) + currentInputs = createData() + workflowStoreState = { + appId: undefined, + pipelineId: undefined, + nodesDefaultConfigs: { + [BlockEnum.Code]: createData({ + code_language: CodeLanguage.javascript, + code: 'function main() { return { default_result: "" } }', + variables: [], + outputs: createOutputs('default_result'), + }), + }, + } + + mockUseNodesReadOnly.mockReturnValue({ + nodesReadOnly: false, + getNodesReadOnly: () => false, + }) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs: mockSetInputs, + })) + mockUseVarList.mockReturnValue({ + handleVarListChange: mockHandleVarListChange, + handleAddVariable: mockHandleAddVariable, + } as ReturnType) + mockUseOutputVarList.mockReturnValue({ + handleVarsChange: mockHandleVarsChange, + handleAddVariable: mockHandleAddOutputVariable, + handleRemoveVariable: mockHandleRemoveVariable, + isShowRemoveVarConfirm: false, + hideRemoveVarConfirm: mockHideRemoveVarConfirm, + onRemoveVarConfirm: mockOnRemoveVarConfirm, + } as ReturnType) + mockUseStore.mockImplementation(selector => selector(workflowStoreState as never)) + mockFetchNodeDefault.mockResolvedValue({ config: javaScriptConfig } as never) + mockFetchPipelineNodeDefault.mockResolvedValue({ config: javaScriptConfig } as never) + mockFetchNodeDefault + .mockResolvedValueOnce({ config: javaScriptConfig } as never) + .mockResolvedValueOnce({ config: pythonConfig } as never) + mockFetchPipelineNodeDefault + .mockResolvedValueOnce({ config: javaScriptConfig } as never) + .mockResolvedValueOnce({ config: pythonConfig } as never) + }) + + it('hydrates node defaults when the code payload is empty and syncs output key order', async () => { + currentInputs = createData({ + code: '', + variables: [], + outputs: {}, + }) + + const { result } = renderHook(() => useConfig('code-node', currentInputs)) + + await waitFor(() => { + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: workflowStoreState.nodesDefaultConfigs?.[BlockEnum.Code]?.code, + outputs: workflowStoreState.nodesDefaultConfigs?.[BlockEnum.Code]?.outputs, + })) + }) + + expect(result.current.handleVarListChange).toBe(mockHandleVarListChange) + expect(result.current.handleAddVariable).toBe(mockHandleAddVariable) + expect(result.current.handleVarsChange).toBe(mockHandleVarsChange) + expect(result.current.handleAddOutputVariable).toBe(mockHandleAddOutputVariable) + expect(result.current.handleRemoveVariable).toBe(mockHandleRemoveVariable) + expect(result.current.hideRemoveVarConfirm).toBe(mockHideRemoveVarConfirm) + expect(result.current.onRemoveVarConfirm).toBe(mockOnRemoveVarConfirm) + expect(result.current.outputKeyOrders).toEqual(['default_result']) + expect(result.current.filterVar({ type: VarType.file } as Var)).toBe(true) + expect(result.current.filterVar({ type: VarType.secret } as Var)).toBe(true) + }) + + it('fetches app and pipeline defaults, switches language, and updates code and output vars together', async () => { + workflowStoreState.appId = 'app-1' + workflowStoreState.pipelineId = 'pipeline-1' + + const { result } = renderHook(() => useConfig('code-node', currentInputs)) + + await waitFor(() => { + expect(mockFetchNodeDefault).toHaveBeenCalledWith('app-1', BlockEnum.Code, { code_language: CodeLanguage.javascript }) + expect(mockFetchNodeDefault).toHaveBeenCalledWith('app-1', BlockEnum.Code, { code_language: CodeLanguage.python3 }) + expect(mockFetchPipelineNodeDefault).toHaveBeenCalledWith('pipeline-1', BlockEnum.Code, { code_language: CodeLanguage.javascript }) + expect(mockFetchPipelineNodeDefault).toHaveBeenCalledWith('pipeline-1', BlockEnum.Code, { code_language: CodeLanguage.python3 }) + }) + + mockSetInputs.mockClear() + + act(() => { + result.current.handleCodeLanguageChange(CodeLanguage.python3) + result.current.handleCodeChange('function main({ bar }) { return { result: bar } }') + result.current.handleCodeAndVarsChange( + 'function main({ amount }) { return { total: amount } }', + [createVariable('amount', VarType.number)], + createOutputs('total', VarType.number), + ) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code_language: CodeLanguage.python3, + code: pythonConfig.code, + variables: pythonConfig.variables, + outputs: pythonConfig.outputs, + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: 'function main({ bar }) { return { result: bar } }', + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: 'function main({ amount }) { return { total: amount } }', + variables: [expect.objectContaining({ variable: 'amount' })], + outputs: createOutputs('total', VarType.number), + })) + expect(result.current.outputKeyOrders).toEqual(['total']) + }) + + it('syncs javascript and python function signatures and keeps json code unchanged', () => { + currentInputs = createData({ + code_language: CodeLanguage.javascript, + code: 'function main() { return { result: "" } }', + variables: [createVariable('foo'), createVariable('bar')], + }) + + const { result, rerender } = renderHook(() => useConfig('code-node', currentInputs)) + + act(() => { + result.current.handleSyncFunctionSignature() + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: 'function main({foo, bar}) { return { result: "" } }', + })) + + mockSetInputs.mockClear() + currentInputs = createData({ + code_language: CodeLanguage.python3, + code: 'def main():\n return {"result": ""}', + variables: [ + createVariable('text', VarType.string), + createVariable('score', VarType.number), + createVariable('payload', VarType.object), + createVariable('items', VarType.array), + createVariable('numbers', VarType.arrayNumber), + createVariable('names', VarType.arrayString), + createVariable('records', VarType.arrayObject), + ], + }) + rerender() + + act(() => { + result.current.handleSyncFunctionSignature() + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: 'def main(text: str, score: float, payload: dict, items: list, numbers: list[float], names: list[str], records: list[dict]):\n return {"result": ""}', + })) + + mockSetInputs.mockClear() + currentInputs = createData({ + code_language: CodeLanguage.json, + code: '{"result": true}', + }) + rerender() + + act(() => { + result.current.handleSyncFunctionSignature() + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: '{"result": true}', + })) + }) + + it('keeps language changes local when no fetched default exists and preserves existing output order', async () => { + currentInputs = createData({ + outputs: { + summary: { + type: VarType.string, + children: null, + }, + count: { + type: VarType.number, + children: null, + }, + }, + }) + workflowStoreState.appId = undefined + workflowStoreState.pipelineId = undefined + + const { result } = renderHook(() => useConfig('code-node', currentInputs)) + + await waitFor(() => { + expect(result.current.outputKeyOrders).toEqual(['summary', 'count']) + }) + + mockSetInputs.mockClear() + + act(() => { + result.current.handleCodeLanguageChange(CodeLanguage.python3) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code_language: CodeLanguage.python3, + code: currentInputs.code, + variables: currentInputs.variables, + outputs: currentInputs.outputs, + })) + }) +}) diff --git a/web/app/components/workflow/nodes/code/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/code/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..39e9d8139a --- /dev/null +++ b/web/app/components/workflow/nodes/code/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,80 @@ +import type { CodeNodeType } from '../types' +import { renderHook } from '@testing-library/react' +import { BlockEnum, InputVarType, VarType } from '@/app/components/workflow/types' +import useNodeCrud from '../../_base/hooks/use-node-crud' +import { CodeLanguage } from '../types' +import useSingleRunFormParams from '../use-single-run-form-params' + +vi.mock('../../_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseNodeCrud = vi.mocked(useNodeCrud) + +const createData = (overrides: Partial = {}): CodeNodeType => ({ + title: 'Code', + desc: '', + type: BlockEnum.Code, + code_language: CodeLanguage.javascript, + code: 'function main({ amount }) { return { result: amount } }', + variables: [{ + variable: 'amount', + value_selector: ['start', 'amount'], + value_type: VarType.number, + }], + outputs: { + result: { + type: VarType.number, + children: null, + }, + }, + ...overrides, +}) + +describe('code/use-single-run-form-params', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodeCrud.mockReturnValue({ + inputs: createData(), + setInputs: vi.fn(), + } as unknown as ReturnType) + }) + + it('builds a single form, updates run input values, and exposes dependent vars', () => { + const setRunInputData = vi.fn() + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'code-node', + payload: createData(), + runInputData: { amount: 1 }, + runInputDataRef: { current: { amount: 1 } }, + getInputVars: () => [], + setRunInputData, + toVarInputs: variables => variables.map(variable => ({ + type: InputVarType.number, + label: variable.variable, + variable: variable.variable, + required: false, + })), + })) + + expect(result.current.forms).toEqual([{ + inputs: [{ + type: InputVarType.number, + label: 'amount', + variable: 'amount', + required: false, + }], + values: { amount: 1 }, + onChange: expect.any(Function), + }]) + + result.current.forms[0]?.onChange({ amount: 3 }) + + expect(setRunInputData).toHaveBeenCalledWith({ amount: 3 }) + expect(result.current.getDependentVars()).toEqual([['start', 'amount']]) + expect(result.current.getDependentVar('amount')).toEqual(['start', 'amount']) + expect(result.current.getDependentVar('missing')).toBeUndefined() + }) +}) diff --git a/web/app/components/workflow/nodes/data-source/__tests__/before-run-form.spec.tsx b/web/app/components/workflow/nodes/data-source/__tests__/before-run-form.spec.tsx new file mode 100644 index 0000000000..c12ec212bf --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/__tests__/before-run-form.spec.tsx @@ -0,0 +1,205 @@ +import type { ReactNode } from 'react' +import type { CustomRunFormProps, DataSourceNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { DatasourceType } from '@/models/pipeline' +import { FlowType } from '@/types/common' +import { BlockEnum } from '../../../types' +import BeforeRunForm from '../before-run-form' +import useBeforeRunForm from '../hooks/use-before-run-form' + +const mockUseDataSourceStore = vi.hoisted(() => vi.fn()) +const mockSetCurrentCredentialId = vi.hoisted(() => vi.fn()) +const mockClearOnlineDocumentData = vi.hoisted(() => vi.fn()) +const mockClearWebsiteCrawlData = vi.hoisted(() => vi.fn()) +const mockClearOnlineDriveData = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store', () => ({ + useDataSourceStore: () => mockUseDataSourceStore(), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store/provider', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) => <>{children}, +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/local-file', () => ({ + __esModule: true, + default: ({ allowedExtensions }: { allowedExtensions: string[] }) =>
{allowedExtensions.join(',')}
, +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/online-documents', () => ({ + __esModule: true, + default: ({ onCredentialChange }: { onCredentialChange: (credentialId: string) => void }) => ( + + ), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl', () => ({ + __esModule: true, + default: ({ onCredentialChange }: { onCredentialChange: (credentialId: string) => void }) => ( + + ), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/online-drive', () => ({ + __esModule: true, + default: ({ onCredentialChange }: { onCredentialChange: (credentialId: string) => void }) => ( + + ), +})) + +vi.mock('@/app/components/rag-pipeline/components/panel/test-run/preparation/hooks', () => ({ + useOnlineDocument: () => ({ clearOnlineDocumentData: mockClearOnlineDocumentData }), + useWebsiteCrawl: () => ({ clearWebsiteCrawlData: mockClearWebsiteCrawlData }), + useOnlineDrive: () => ({ clearOnlineDriveData: mockClearOnlineDriveData }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/before-run-form/panel-wrap', () => ({ + __esModule: true, + default: ({ nodeName, onHide, children }: { nodeName: string, onHide: () => void, children: ReactNode }) => ( +
+
{nodeName}
+ + {children} +
+ ), +})) + +vi.mock('../hooks/use-before-run-form', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseBeforeRunForm = vi.mocked(useBeforeRunForm) + +const createData = (overrides: Partial = {}): DataSourceNodeType => ({ + title: 'Datasource', + desc: '', + type: BlockEnum.DataSource, + plugin_id: 'plugin-id', + provider_type: DatasourceType.localFile, + provider_name: 'file', + datasource_name: 'local-file', + datasource_label: 'Local File', + datasource_parameters: {}, + datasource_configurations: {}, + fileExtensions: ['pdf', 'md'], + ...overrides, +}) + +const createProps = (overrides: Partial = {}): CustomRunFormProps => ({ + nodeId: 'data-source-node', + flowId: 'flow-id', + flowType: FlowType.ragPipeline, + payload: createData(), + setRunResult: vi.fn(), + setIsRunAfterSingleRun: vi.fn(), + isPaused: false, + isRunAfterSingleRun: false, + onSuccess: vi.fn(), + onCancel: vi.fn(), + appendNodeInspectVars: vi.fn(), + ...overrides, +}) + +describe('data-source/before-run-form', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseDataSourceStore.mockReturnValue({ + getState: () => ({ + setCurrentCredentialId: mockSetCurrentCredentialId, + }), + }) + mockUseBeforeRunForm.mockReturnValue({ + isPending: false, + handleRunWithSyncDraft: vi.fn(), + datasourceType: DatasourceType.localFile, + datasourceNodeData: createData(), + startRunBtnDisabled: false, + }) + }) + + it('renders the local-file preparation form and triggers run/cancel actions', async () => { + const user = userEvent.setup() + const onCancel = vi.fn() + const handleRunWithSyncDraft = vi.fn() + + mockUseBeforeRunForm.mockReturnValueOnce({ + isPending: false, + handleRunWithSyncDraft, + datasourceType: DatasourceType.localFile, + datasourceNodeData: createData(), + startRunBtnDisabled: false, + }) + + render() + + expect(screen.getByText('Datasource')).toBeInTheDocument() + expect(screen.getByText('pdf,md')).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + await user.click(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })) + + expect(onCancel).toHaveBeenCalled() + expect(handleRunWithSyncDraft).toHaveBeenCalled() + }) + + it('clears stale online document data before switching credentials', async () => { + const user = userEvent.setup() + + mockUseBeforeRunForm.mockReturnValueOnce({ + isPending: false, + handleRunWithSyncDraft: vi.fn(), + datasourceType: DatasourceType.onlineDocument, + datasourceNodeData: createData({ provider_type: DatasourceType.onlineDocument }), + startRunBtnDisabled: true, + }) + + render() + + await user.click(screen.getByRole('button', { name: 'online-documents' })) + + expect(mockClearOnlineDocumentData).toHaveBeenCalled() + expect(mockSetCurrentCredentialId).toHaveBeenCalledWith('credential-doc') + expect(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })).toBeDisabled() + }) + + it('clears website crawl data before switching credentials', async () => { + const user = userEvent.setup() + + mockUseBeforeRunForm.mockReturnValueOnce({ + isPending: false, + handleRunWithSyncDraft: vi.fn(), + datasourceType: DatasourceType.websiteCrawl, + datasourceNodeData: createData({ provider_type: DatasourceType.websiteCrawl }), + startRunBtnDisabled: false, + }) + + render() + + await user.click(screen.getByRole('button', { name: 'website-crawl' })) + + expect(mockClearWebsiteCrawlData).toHaveBeenCalled() + expect(mockSetCurrentCredentialId).toHaveBeenCalledWith('credential-site') + }) + + it('clears online drive data before switching credentials', async () => { + const user = userEvent.setup() + + mockUseBeforeRunForm.mockReturnValueOnce({ + isPending: false, + handleRunWithSyncDraft: vi.fn(), + datasourceType: DatasourceType.onlineDrive, + datasourceNodeData: createData({ provider_type: DatasourceType.onlineDrive }), + startRunBtnDisabled: false, + }) + + render() + + await user.click(screen.getByRole('button', { name: 'online-drive' })) + + expect(mockClearOnlineDriveData).toHaveBeenCalled() + expect(mockSetCurrentCredentialId).toHaveBeenCalledWith('credential-drive') + }) +}) diff --git a/web/app/components/workflow/nodes/data-source/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/data-source/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..8160da6502 --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/__tests__/panel.spec.tsx @@ -0,0 +1,194 @@ +import type { ReactNode } from 'react' +import type { DataSourceNodeType } from '../types' +import type { NodePanelProps } from '@/app/components/workflow/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' +import { useNodesReadOnly } from '@/app/components/workflow/hooks' +import { useStore } from '@/app/components/workflow/store' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import useMatchSchemaType, { getMatchedSchemaType } from '../../_base/components/variable/use-match-schema-type' +import ToolForm from '../../tool/components/tool-form' +import { useConfig } from '../hooks/use-config' +import Panel from '../panel' + +const mockWrapStructuredVarItem = vi.hoisted(() => vi.fn((payload: unknown) => payload)) + +vi.mock('@/app/components/base/tag-input', () => ({ + __esModule: true, + default: ({ + items, + onChange, + placeholder, + }: { + items: string[] + onChange: (items: string[]) => void + placeholder?: string + }) => ( + + ), +})) + +vi.mock('@/app/components/tools/utils/to-form-schema', () => ({ + toolParametersToFormSchemas: vi.fn(), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: vi.fn(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: vi.fn(), +})) + +vi.mock('@/app/components/workflow/utils/tool', () => ({ + wrapStructuredVarItem: (payload: unknown) => mockWrapStructuredVarItem(payload), +})) + +vi.mock('../../_base/components/output-vars', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) =>
{children}
, + VarItem: ({ name, type }: { name: string, type: string }) =>
{`${name}:${type}`}
, +})) + +vi.mock('../../_base/components/variable/object-child-tree-panel/show', () => ({ + __esModule: true, + default: ({ payload }: { payload: { name: string } }) =>
{payload.name}
, +})) + +vi.mock('../../_base/components/variable/use-match-schema-type', () => ({ + __esModule: true, + default: vi.fn(), + getMatchedSchemaType: vi.fn(), +})) + +vi.mock('../../tool/components/tool-form', () => ({ + __esModule: true, + default: vi.fn(({ onChange, onManageInputField }: { onChange: (value: unknown) => void, onManageInputField?: () => void }) => ( +
+ + +
+ )), +})) + +vi.mock('../hooks/use-config', () => ({ + useConfig: vi.fn(), +})) + +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseStore = vi.mocked(useStore) +const mockUseConfig = vi.mocked(useConfig) +const mockToolParametersToFormSchemas = vi.mocked(toolParametersToFormSchemas) +const mockUseMatchSchemaType = vi.mocked(useMatchSchemaType) +const mockGetMatchedSchemaType = vi.mocked(getMatchedSchemaType) +const mockToolForm = vi.mocked(ToolForm) + +const setShowInputFieldPanel = vi.fn() + +const createData = (overrides: Partial = {}): DataSourceNodeType => ({ + title: 'Datasource', + desc: '', + type: BlockEnum.DataSource, + plugin_id: 'plugin-1', + provider_type: 'remote', + provider_name: 'provider', + datasource_name: 'source-a', + datasource_label: 'Source A', + datasource_parameters: {}, + datasource_configurations: {}, + fileExtensions: ['pdf'], + ...overrides, +}) + +const panelProps = {} as NodePanelProps['panelProps'] + +describe('data-source/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseStore.mockImplementation((selector) => { + const select = selector as (state: unknown) => unknown + return select({ + dataSourceList: [{ + plugin_id: 'plugin-1', + is_authorized: true, + tools: [{ + name: 'source-a', + parameters: [{ name: 'dataset' }], + }], + }], + pipelineId: 'pipeline-1', + setShowInputFieldPanel, + }) + }) + mockUseConfig.mockReturnValue({ + handleFileExtensionsChange: vi.fn(), + handleParametersChange: vi.fn(), + outputSchema: [], + hasObjectOutput: false, + }) + mockToolParametersToFormSchemas.mockReturnValue([{ name: 'dataset' }] as never) + mockUseMatchSchemaType.mockReturnValue({ schemaTypeDefinitions: {} } as ReturnType) + mockGetMatchedSchemaType.mockReturnValue('') + }) + + it('renders the authorized tool form path and forwards parameter changes', () => { + const handleParametersChange = vi.fn() + mockUseConfig.mockReturnValueOnce({ + handleFileExtensionsChange: vi.fn(), + handleParametersChange, + outputSchema: [{ + name: 'metadata', + value: { type: 'object' }, + }], + hasObjectOutput: true, + }) + mockGetMatchedSchemaType.mockReturnValueOnce('json') + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'tool-form-change' })) + fireEvent.click(screen.getByRole('button', { name: 'manage-input-field' })) + + expect(handleParametersChange).toHaveBeenCalledWith({ dataset: 'docs' }) + expect(setShowInputFieldPanel).toHaveBeenCalledWith(true) + expect(mockToolForm).toHaveBeenCalledWith(expect.objectContaining({ + nodeId: 'data-source-node', + showManageInputField: true, + value: {}, + }), undefined) + expect(screen.getByText('metadata')).toBeInTheDocument() + }) + + it('renders the local-file path and updates supported file extensions', () => { + const handleFileExtensionsChange = vi.fn() + mockUseConfig.mockReturnValueOnce({ + handleFileExtensionsChange, + handleParametersChange: vi.fn(), + outputSchema: [], + hasObjectOutput: false, + }) + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'workflow.nodes.dataSource.supportedFileFormatsPlaceholder' })) + + expect(handleFileExtensionsChange).toHaveBeenCalledWith(['pdf', 'txt']) + expect(screen.getByText(`datasource_type:${VarType.string}`)).toBeInTheDocument() + expect(screen.getByText(`file:${VarType.file}`)).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.branches.spec.tsx b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.branches.spec.tsx new file mode 100644 index 0000000000..09172dd673 --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.branches.spec.tsx @@ -0,0 +1,308 @@ +import type { CustomRunFormProps, DataSourceNodeType } from '../../types' +import type { NodeRunResult, VarInInspect } from '@/types/workflow' +import { act, renderHook } from '@testing-library/react' +import { useStoreApi } from 'reactflow' +import { useDataSourceStore, useDataSourceStoreWithSelector } from '@/app/components/datasets/documents/create-from-pipeline/data-source/store' +import { BlockEnum, NodeRunningStatus } from '@/app/components/workflow/types' +import { DatasourceType } from '@/models/pipeline' +import { useDatasourceSingleRun } from '@/service/use-pipeline' +import { useInvalidLastRun } from '@/service/use-workflow' +import { fetchNodeInspectVars } from '@/service/workflow' +import { FlowType } from '@/types/common' +import { useNodeDataUpdate, useNodesSyncDraft } from '../../../../hooks' +import useBeforeRunForm from '../use-before-run-form' + +type DataSourceStoreState = { + currentNodeIdRef: { current: string } + currentCredentialId: string + setCurrentCredentialId: (credentialId: string) => void + currentCredentialIdRef: { current: string } + localFileList: Array<{ + file: { + id: string + name: string + type: string + size: number + extension: string + mime_type: string + } + }> + onlineDocuments: Array> + websitePages: Array> + selectedFileIds: string[] + onlineDriveFileList: Array<{ id: string, type: string }> + bucket?: string +} + +type DatasourceSingleRunOptions = { + onError?: () => void + onSettled?: (data?: NodeRunResult) => void +} + +const mockHandleNodeDataUpdate = vi.hoisted(() => vi.fn()) +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockMutateAsync = vi.hoisted(() => vi.fn()) +const mockInvalidLastRun = vi.hoisted(() => vi.fn()) +const mockFetchNodeInspectVars = vi.hoisted(() => vi.fn()) +const mockUseDataSourceStore = vi.hoisted(() => vi.fn()) +const mockUseDataSourceStoreWithSelector = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useStoreApi: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodeDataUpdate: vi.fn(), + useNodesSyncDraft: vi.fn(), +})) + +vi.mock('@/service/use-pipeline', () => ({ + useDatasourceSingleRun: vi.fn(), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidLastRun: vi.fn(), +})) + +vi.mock('@/service/workflow', () => ({ + fetchNodeInspectVars: vi.fn(), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store', () => ({ + useDataSourceStore: vi.fn(), + useDataSourceStoreWithSelector: vi.fn(), +})) + +const mockUseStoreApi = vi.mocked(useStoreApi) +const mockUseNodeDataUpdateHook = vi.mocked(useNodeDataUpdate) +const mockUseNodesSyncDraftHook = vi.mocked(useNodesSyncDraft) +const mockUseDatasourceSingleRunHook = vi.mocked(useDatasourceSingleRun) +const mockUseInvalidLastRunHook = vi.mocked(useInvalidLastRun) +const mockFetchNodeInspectVarsFn = vi.mocked(fetchNodeInspectVars) +const mockUseDataSourceStoreHook = vi.mocked(useDataSourceStore) +const mockUseDataSourceStoreWithSelectorHook = vi.mocked(useDataSourceStoreWithSelector) + +const createData = (overrides: Partial = {}): DataSourceNodeType => ({ + title: 'Datasource', + desc: '', + type: BlockEnum.DataSource, + plugin_id: 'plugin-id', + provider_type: DatasourceType.localFile, + provider_name: 'provider', + datasource_name: 'local-file', + datasource_label: 'Local File', + datasource_parameters: {}, + datasource_configurations: {}, + fileExtensions: ['pdf'], + ...overrides, +}) + +const createProps = (overrides: Partial = {}): CustomRunFormProps => ({ + nodeId: 'data-source-node', + flowId: 'flow-id', + flowType: FlowType.ragPipeline, + payload: createData(), + setRunResult: vi.fn(), + setIsRunAfterSingleRun: vi.fn(), + isPaused: false, + isRunAfterSingleRun: false, + onSuccess: vi.fn(), + onCancel: vi.fn(), + appendNodeInspectVars: vi.fn(), + ...overrides, +}) + +describe('data-source/hooks/use-before-run-form branches', () => { + let dataSourceStoreState: DataSourceStoreState + + beforeEach(() => { + vi.clearAllMocks() + + dataSourceStoreState = { + currentNodeIdRef: { current: 'data-source-node' }, + currentCredentialId: 'credential-1', + setCurrentCredentialId: vi.fn(), + currentCredentialIdRef: { current: 'credential-1' }, + localFileList: [], + onlineDocuments: [], + websitePages: [], + selectedFileIds: [], + onlineDriveFileList: [], + bucket: 'drive-bucket', + } + + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes: () => [{ id: 'data-source-node', data: { title: 'Datasource' } }], + }), + } as ReturnType) + + mockUseNodeDataUpdateHook.mockReturnValue({ + handleNodeDataUpdate: mockHandleNodeDataUpdate, + handleNodeDataUpdateWithSyncDraft: vi.fn(), + } as ReturnType) + mockUseNodesSyncDraftHook.mockReturnValue({ + handleSyncWorkflowDraft: (...args: unknown[]) => { + mockHandleSyncWorkflowDraft(...args) + const callbacks = args[2] as { onSuccess?: () => void } | undefined + callbacks?.onSuccess?.() + }, + } as ReturnType) + mockUseDatasourceSingleRunHook.mockReturnValue({ + mutateAsync: (...args: unknown[]) => mockMutateAsync(...args), + isPending: false, + } as ReturnType) + mockUseInvalidLastRunHook.mockReturnValue(mockInvalidLastRun) + mockFetchNodeInspectVarsFn.mockImplementation((...args: unknown[]) => mockFetchNodeInspectVars(...args)) + mockUseDataSourceStoreHook.mockImplementation(() => mockUseDataSourceStore()) + mockUseDataSourceStoreWithSelectorHook.mockImplementation(selector => + mockUseDataSourceStoreWithSelector(selector as unknown as (state: DataSourceStoreState) => unknown)) + + mockUseDataSourceStore.mockImplementation(() => ({ + getState: () => dataSourceStoreState, + })) + mockUseDataSourceStoreWithSelector.mockImplementation((selector: (state: DataSourceStoreState) => unknown) => + selector(dataSourceStoreState)) + mockFetchNodeInspectVars.mockResolvedValue([{ name: 'metadata' }] as VarInInspect[]) + }) + + it('derives disabled states for online documents and website crawl sources', () => { + const { result, rerender } = renderHook( + ({ payload }) => useBeforeRunForm(createProps({ payload })), + { + initialProps: { + payload: createData({ provider_type: DatasourceType.onlineDocument }), + }, + }, + ) + + expect(result.current.startRunBtnDisabled).toBe(true) + + dataSourceStoreState.onlineDocuments = [{ + workspace_id: 'workspace-1', + id: 'doc-1', + title: 'Document', + }] + rerender({ payload: createData({ provider_type: DatasourceType.onlineDocument }) }) + expect(result.current.startRunBtnDisabled).toBe(false) + + rerender({ payload: createData({ provider_type: DatasourceType.websiteCrawl }) }) + expect(result.current.startRunBtnDisabled).toBe(true) + + dataSourceStoreState.websitePages = [{ url: 'https://example.com' }] + rerender({ payload: createData({ provider_type: DatasourceType.websiteCrawl }) }) + expect(result.current.startRunBtnDisabled).toBe(false) + }) + + it('returns the settled run result directly when chained single-run execution should stop', async () => { + dataSourceStoreState.localFileList = [{ + file: { + id: 'file-1', + name: 'doc.pdf', + type: 'document', + size: 12, + extension: 'pdf', + mime_type: 'application/pdf', + }, + }] + + mockMutateAsync.mockImplementation((_payload: unknown, options: DatasourceSingleRunOptions) => { + options.onSettled?.({ status: NodeRunningStatus.Succeeded } as NodeRunResult) + return Promise.resolve(undefined) + }) + + const props = createProps({ + isRunAfterSingleRun: true, + payload: createData({ + _singleRunningStatus: NodeRunningStatus.Running, + } as Partial), + }) + const { result } = renderHook(() => useBeforeRunForm(props)) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(props.setRunResult).toHaveBeenCalledWith({ status: NodeRunningStatus.Succeeded }) + expect(mockFetchNodeInspectVars).not.toHaveBeenCalled() + expect(props.onSuccess).not.toHaveBeenCalled() + }) + + it('builds online document datasource info before running', async () => { + dataSourceStoreState.onlineDocuments = [{ + workspace_id: 'workspace-1', + id: 'doc-1', + title: 'Document', + url: 'https://example.com/doc', + }] + + mockMutateAsync.mockImplementation((payload: unknown, options: DatasourceSingleRunOptions) => { + options.onSettled?.({ status: NodeRunningStatus.Succeeded } as NodeRunResult) + return Promise.resolve(payload) + }) + + const { result } = renderHook(() => useBeforeRunForm(createProps({ + payload: createData({ provider_type: DatasourceType.onlineDocument }), + }))) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith(expect.objectContaining({ + datasource_type: DatasourceType.onlineDocument, + datasource_info: { + workspace_id: 'workspace-1', + page: { + id: 'doc-1', + title: 'Document', + url: 'https://example.com/doc', + }, + credential_id: 'credential-1', + }, + }), expect.any(Object)) + }) + + it('builds website crawl datasource info and skips the failure update while paused', async () => { + dataSourceStoreState.websitePages = [{ + url: 'https://example.com', + title: 'Example', + }] + + mockMutateAsync.mockImplementation((payload: unknown, options: DatasourceSingleRunOptions) => { + options.onError?.() + return Promise.resolve(payload) + }) + + const { result } = renderHook(() => useBeforeRunForm(createProps({ + isPaused: true, + payload: createData({ provider_type: DatasourceType.websiteCrawl }), + }))) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith(expect.objectContaining({ + datasource_type: DatasourceType.websiteCrawl, + datasource_info: { + url: 'https://example.com', + title: 'Example', + credential_id: 'credential-1', + }, + }), expect.any(Object)) + expect(mockInvalidLastRun).toHaveBeenCalled() + expect(mockHandleNodeDataUpdate).not.toHaveBeenCalledWith(expect.objectContaining({ + data: expect.objectContaining({ + _singleRunningStatus: NodeRunningStatus.Failed, + }), + })) + }) +}) diff --git a/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.spec.tsx b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.spec.tsx new file mode 100644 index 0000000000..b4e79b3334 --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.spec.tsx @@ -0,0 +1,307 @@ +import type { CustomRunFormProps, DataSourceNodeType } from '../../types' +import type { NodeRunResult, VarInInspect } from '@/types/workflow' +import { act, renderHook } from '@testing-library/react' +import { useStoreApi } from 'reactflow' +import { useDataSourceStore, useDataSourceStoreWithSelector } from '@/app/components/datasets/documents/create-from-pipeline/data-source/store' +import { BlockEnum, NodeRunningStatus } from '@/app/components/workflow/types' +import { DatasourceType } from '@/models/pipeline' +import { useDatasourceSingleRun } from '@/service/use-pipeline' +import { useInvalidLastRun } from '@/service/use-workflow' +import { fetchNodeInspectVars } from '@/service/workflow' +import { TransferMethod } from '@/types/app' +import { FlowType } from '@/types/common' +import { useNodeDataUpdate, useNodesSyncDraft } from '../../../../hooks' +import useBeforeRunForm from '../use-before-run-form' + +type DataSourceStoreState = { + currentNodeIdRef: { current: string } + currentCredentialId: string + setCurrentCredentialId: (credentialId: string) => void + currentCredentialIdRef: { current: string } + localFileList: Array<{ + file: { + id: string + name: string + type: string + size: number + extension: string + mime_type: string + } + }> + onlineDocuments: Array> + websitePages: Array> + selectedFileIds: string[] + onlineDriveFileList: Array<{ id: string, type: string }> + bucket?: string +} + +type DatasourceSingleRunOptions = { + onError?: () => void + onSettled?: (data?: NodeRunResult) => void +} + +const mockHandleNodeDataUpdate = vi.hoisted(() => vi.fn()) +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockMutateAsync = vi.hoisted(() => vi.fn()) +const mockInvalidLastRun = vi.hoisted(() => vi.fn()) +const mockFetchNodeInspectVars = vi.hoisted(() => vi.fn()) +const mockUseDataSourceStore = vi.hoisted(() => vi.fn()) +const mockUseDataSourceStoreWithSelector = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useStoreApi: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodeDataUpdate: vi.fn(), + useNodesSyncDraft: vi.fn(), +})) + +vi.mock('@/service/use-pipeline', () => ({ + useDatasourceSingleRun: vi.fn(), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidLastRun: vi.fn(), +})) + +vi.mock('@/service/workflow', () => ({ + fetchNodeInspectVars: vi.fn(), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store', () => ({ + useDataSourceStore: vi.fn(), + useDataSourceStoreWithSelector: vi.fn(), +})) + +const mockUseStoreApi = vi.mocked(useStoreApi) +const mockUseNodeDataUpdateHook = vi.mocked(useNodeDataUpdate) +const mockUseNodesSyncDraftHook = vi.mocked(useNodesSyncDraft) +const mockUseDatasourceSingleRunHook = vi.mocked(useDatasourceSingleRun) +const mockUseInvalidLastRunHook = vi.mocked(useInvalidLastRun) +const mockFetchNodeInspectVarsFn = vi.mocked(fetchNodeInspectVars) +const mockUseDataSourceStoreHook = vi.mocked(useDataSourceStore) +const mockUseDataSourceStoreWithSelectorHook = vi.mocked(useDataSourceStoreWithSelector) + +const createData = (overrides: Partial = {}): DataSourceNodeType => ({ + title: 'Datasource', + desc: '', + type: BlockEnum.DataSource, + plugin_id: 'plugin-id', + provider_type: DatasourceType.localFile, + provider_name: 'provider', + datasource_name: 'local-file', + datasource_label: 'Local File', + datasource_parameters: {}, + datasource_configurations: {}, + fileExtensions: ['pdf'], + ...overrides, +}) + +const createProps = (overrides: Partial = {}): CustomRunFormProps => ({ + nodeId: 'data-source-node', + flowId: 'flow-id', + flowType: FlowType.ragPipeline, + payload: createData(), + setRunResult: vi.fn(), + setIsRunAfterSingleRun: vi.fn(), + isPaused: false, + isRunAfterSingleRun: false, + onSuccess: vi.fn(), + onCancel: vi.fn(), + appendNodeInspectVars: vi.fn(), + ...overrides, +}) + +describe('data-source/hooks/use-before-run-form', () => { + let dataSourceStoreState: DataSourceStoreState + + beforeEach(() => { + vi.clearAllMocks() + + dataSourceStoreState = { + currentNodeIdRef: { current: 'data-source-node' }, + currentCredentialId: 'credential-1', + setCurrentCredentialId: vi.fn(), + currentCredentialIdRef: { current: 'credential-1' }, + localFileList: [], + onlineDocuments: [], + websitePages: [], + selectedFileIds: [], + onlineDriveFileList: [], + bucket: 'drive-bucket', + } + + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes: () => [ + { + id: 'data-source-node', + data: { + title: 'Datasource', + }, + }, + ], + }), + } as ReturnType) + + mockUseNodeDataUpdateHook.mockReturnValue({ + handleNodeDataUpdate: mockHandleNodeDataUpdate, + handleNodeDataUpdateWithSyncDraft: vi.fn(), + } as ReturnType) + mockUseNodesSyncDraftHook.mockReturnValue({ + handleSyncWorkflowDraft: (...args: unknown[]) => { + mockHandleSyncWorkflowDraft(...args) + const callbacks = args[2] as { onSuccess?: () => void } | undefined + callbacks?.onSuccess?.() + }, + } as ReturnType) + mockUseDatasourceSingleRunHook.mockReturnValue({ + mutateAsync: (...args: unknown[]) => mockMutateAsync(...args), + isPending: false, + } as ReturnType) + mockUseInvalidLastRunHook.mockReturnValue(mockInvalidLastRun) + mockFetchNodeInspectVarsFn.mockImplementation((...args: unknown[]) => mockFetchNodeInspectVars(...args)) + mockUseDataSourceStoreHook.mockImplementation(() => mockUseDataSourceStore()) + mockUseDataSourceStoreWithSelectorHook.mockImplementation(selector => + mockUseDataSourceStoreWithSelector(selector as unknown as (state: DataSourceStoreState) => unknown)) + + mockUseDataSourceStore.mockImplementation(() => ({ + getState: () => dataSourceStoreState, + })) + mockUseDataSourceStoreWithSelector.mockImplementation((selector: (state: DataSourceStoreState) => unknown) => + selector(dataSourceStoreState)) + mockFetchNodeInspectVars.mockResolvedValue([{ name: 'metadata' }] as VarInInspect[]) + }) + + it('derives the run button disabled state from the selected datasource payload', () => { + const { result, rerender } = renderHook( + ({ payload }) => useBeforeRunForm(createProps({ payload })), + { + initialProps: { + payload: createData(), + }, + }, + ) + + expect(result.current.startRunBtnDisabled).toBe(true) + + dataSourceStoreState.localFileList = [{ + file: { + id: 'file-1', + name: 'doc.pdf', + type: 'document', + size: 12, + extension: 'pdf', + mime_type: 'application/pdf', + }, + }] + rerender({ payload: createData() }) + expect(result.current.startRunBtnDisabled).toBe(false) + + dataSourceStoreState.selectedFileIds = [] + rerender({ + payload: createData({ + provider_type: DatasourceType.onlineDrive, + }), + }) + expect(result.current.startRunBtnDisabled).toBe(true) + }) + + it('syncs the draft, runs the datasource, and appends inspect vars on success', async () => { + dataSourceStoreState.localFileList = [{ + file: { + id: 'file-1', + name: 'doc.pdf', + type: 'document', + size: 12, + extension: 'pdf', + mime_type: 'application/pdf', + }, + }] + + mockMutateAsync.mockImplementation((payload: unknown, options: DatasourceSingleRunOptions) => { + options.onSettled?.({ status: NodeRunningStatus.Succeeded } as NodeRunResult) + return Promise.resolve(payload) + }) + + const props = createProps() + const { result } = renderHook(() => useBeforeRunForm(props)) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(props.setIsRunAfterSingleRun).toHaveBeenCalledWith(true) + expect(mockHandleNodeDataUpdate).toHaveBeenNthCalledWith(1, { + id: 'data-source-node', + data: expect.objectContaining({ + _singleRunningStatus: NodeRunningStatus.Running, + }), + }) + expect(mockMutateAsync).toHaveBeenCalledWith(expect.objectContaining({ + pipeline_id: 'flow-id', + start_node_id: 'data-source-node', + datasource_type: DatasourceType.localFile, + datasource_info: expect.objectContaining({ + related_id: 'file-1', + transfer_method: TransferMethod.local_file, + }), + }), expect.any(Object)) + expect(mockFetchNodeInspectVars).toHaveBeenCalledWith(FlowType.ragPipeline, 'flow-id', 'data-source-node') + expect(props.appendNodeInspectVars).toHaveBeenCalledWith('data-source-node', [{ name: 'metadata' }], [ + { + id: 'data-source-node', + data: { + title: 'Datasource', + }, + }, + ]) + expect(props.onSuccess).toHaveBeenCalled() + expect(mockHandleNodeDataUpdate).toHaveBeenLastCalledWith({ + id: 'data-source-node', + data: expect.objectContaining({ + _isSingleRun: false, + _singleRunningStatus: NodeRunningStatus.Succeeded, + }), + }) + }) + + it('marks the last run invalid and updates the node to failed when the single run errors', async () => { + dataSourceStoreState.selectedFileIds = ['drive-file-1'] + dataSourceStoreState.onlineDriveFileList = [{ + id: 'drive-file-1', + type: 'file', + }] + + mockMutateAsync.mockImplementation((_payload: unknown, options: DatasourceSingleRunOptions) => { + options.onError?.() + return Promise.resolve(undefined) + }) + + const { result } = renderHook(() => useBeforeRunForm(createProps({ + payload: createData({ + provider_type: DatasourceType.onlineDrive, + }), + }))) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(mockInvalidLastRun).toHaveBeenCalled() + expect(mockHandleNodeDataUpdate).toHaveBeenLastCalledWith({ + id: 'data-source-node', + data: expect.objectContaining({ + _isSingleRun: false, + _singleRunningStatus: NodeRunningStatus.Failed, + }), + }) + }) +}) diff --git a/web/app/components/workflow/nodes/document-extractor/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/document-extractor/__tests__/node.spec.tsx new file mode 100644 index 0000000000..2044d7e6b9 --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/__tests__/node.spec.tsx @@ -0,0 +1,74 @@ +import type { DocExtractorNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import { useNodes } from 'reactflow' +import { BlockEnum } from '@/app/components/workflow/types' +import Node from '../node' + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useNodes: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/variable-label', () => ({ + VariableLabelInNode: ({ + variables, + nodeTitle, + nodeType, + }: { + variables: string[] + nodeTitle?: string + nodeType?: BlockEnum + }) =>
{`${nodeTitle}:${nodeType}:${variables.join('.')}`}
, +})) + +const mockUseNodes = vi.mocked(useNodes) + +const createData = (overrides: Partial = {}): DocExtractorNodeType => ({ + title: 'Document Extractor', + desc: '', + type: BlockEnum.DocExtractor, + variable_selector: ['node-1', 'files'], + is_array_file: false, + ...overrides, +}) + +describe('document-extractor/node', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodes.mockReturnValue([ + { + id: 'node-1', + data: { + title: 'Input Files', + type: BlockEnum.Start, + }, + }, + ] as ReturnType) + }) + + it('renders nothing when no input variable is configured', () => { + const { container } = render( + , + ) + + expect(container).toBeEmptyDOMElement() + }) + + it('renders the selected input variable label', () => { + render( + , + ) + + expect(screen.getByText('workflow.nodes.docExtractor.inputVar')).toBeInTheDocument() + expect(screen.getByText('Input Files:start:node-1.files')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/document-extractor/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/document-extractor/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..06512f94c6 --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/__tests__/panel.spec.tsx @@ -0,0 +1,144 @@ +import type { ReactNode } from 'react' +import type { DocExtractorNodeType } from '../types' +import type { PanelProps } from '@/types/workflow' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { LanguagesSupported } from '@/i18n-config/language' +import { BlockEnum } from '../../../types' +import Panel from '../panel' +import useConfig from '../use-config' + +let mockLocale = 'en-US' + +vi.mock('@/app/components/workflow/nodes/_base/components/field', () => ({ + __esModule: true, + default: ({ title, children }: { title: ReactNode, children: ReactNode }) => ( +
+
{title}
+ {children} +
+ ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/output-vars', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) =>
{children}
, + VarItem: ({ name, type }: { name: string, type: string }) =>
{`${name}:${type}`}
, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/split', () => ({ + __esModule: true, + default: () =>
split
, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({ + __esModule: true, + default: ({ + onChange, + }: { + onChange: (value: string[]) => void + }) => , +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-help-link', () => ({ + useNodeHelpLink: () => 'https://docs.example.com/document-extractor', +})) + +vi.mock('@/service/use-common', () => ({ + useFileSupportTypes: () => ({ + data: { + allowed_extensions: ['PDF', 'md', 'md', 'DOCX'], + }, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: () => mockLocale, +})) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseConfig = vi.mocked(useConfig) + +const createData = (overrides: Partial = {}): DocExtractorNodeType => ({ + title: 'Document Extractor', + desc: '', + type: BlockEnum.DocExtractor, + variable_selector: ['node-1', 'files'], + is_array_file: false, + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + handleVarChanges: vi.fn(), + filterVar: () => true, + ...overrides, +}) + +const panelProps: PanelProps = { + getInputVars: vi.fn(() => []), + toVarInputs: vi.fn(() => []), + runInputData: {}, + runInputDataRef: { current: {} }, + setRunInputData: vi.fn(), + runResult: null, +} + +describe('document-extractor/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockLocale = 'en-US' + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('wires variable changes and renders supported file types for english locales', async () => { + const user = userEvent.setup() + const handleVarChanges = vi.fn() + + mockUseConfig.mockReturnValueOnce(createConfigResult({ + inputs: createData({ is_array_file: false }), + handleVarChanges, + })) + + render( + , + ) + + await user.click(screen.getByRole('button', { name: 'pick-file-var' })) + + expect(handleVarChanges).toHaveBeenCalledWith(['node-1', 'files']) + expect(screen.getByText('workflow.nodes.docExtractor.supportFileTypes:{"types":"pdf, markdown, docx"}')).toBeInTheDocument() + expect(screen.getByRole('link', { name: 'workflow.nodes.docExtractor.learnMore' })).toHaveAttribute( + 'href', + 'https://docs.example.com/document-extractor', + ) + expect(screen.getByText('text:string')).toBeInTheDocument() + }) + + it('uses chinese separators and array output types when the input is an array of files', () => { + mockLocale = LanguagesSupported[1] + mockUseConfig.mockReturnValueOnce(createConfigResult({ + inputs: createData({ is_array_file: true }), + })) + + render( + , + ) + + expect(screen.getByText('workflow.nodes.docExtractor.supportFileTypes:{"types":"pdf、 markdown、 docx"}')).toBeInTheDocument() + expect(screen.getByText('text:array[string]')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/document-extractor/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/document-extractor/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..d988b2751d --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/__tests__/use-config.spec.ts @@ -0,0 +1,100 @@ +import type { DocExtractorNodeType } from '../types' +import { renderHook } from '@testing-library/react' +import { useStoreApi } from 'reactflow' +import { + useIsChatMode, + useNodesReadOnly, + useWorkflow, + useWorkflowVariables, +} from '@/app/components/workflow/hooks' +import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseStoreApi = vi.mocked(useStoreApi) +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseIsChatMode = vi.mocked(useIsChatMode) +const mockUseWorkflow = vi.mocked(useWorkflow) +const mockUseWorkflowVariables = vi.mocked(useWorkflowVariables) + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useStoreApi: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useIsChatMode: vi.fn(), + useNodesReadOnly: vi.fn(), + useWorkflow: vi.fn(), + useWorkflowVariables: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const setInputs = vi.fn() +const getCurrentVariableType = vi.fn() + +const createData = (overrides: Partial = {}): DocExtractorNodeType => ({ + title: 'Document Extractor', + desc: '', + type: BlockEnum.DocExtractor, + variable_selector: ['node-1', 'files'], + is_array_file: false, + ...overrides, +}) + +describe('document-extractor/use-config', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseIsChatMode.mockReturnValue(false) + mockUseWorkflow.mockReturnValue({ + getBeforeNodesInSameBranch: vi.fn(() => [{ id: 'start-node' }]), + } as unknown as ReturnType) + mockUseWorkflowVariables.mockReturnValue({ + getCurrentVariableType, + } as unknown as ReturnType) + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes: () => [ + { id: 'doc-node', parentId: 'loop-1', data: { type: BlockEnum.DocExtractor } }, + { id: 'loop-1', data: { type: BlockEnum.Loop } }, + ], + }), + } as ReturnType) + mockUseNodeCrud.mockReturnValue({ + inputs: createData(), + setInputs, + } as ReturnType) + }) + + it('updates the selected variable and tracks array file output types', () => { + getCurrentVariableType.mockReturnValue(VarType.arrayFile) + + const { result } = renderHook(() => useConfig('doc-node', createData())) + + result.current.handleVarChanges(['node-2', 'files']) + + expect(getCurrentVariableType).toHaveBeenCalled() + expect(setInputs).toHaveBeenCalledWith(expect.objectContaining({ + variable_selector: ['node-2', 'files'], + is_array_file: true, + })) + }) + + it('only accepts file variables in the picker filter', () => { + const { result } = renderHook(() => useConfig('doc-node', createData())) + + expect(result.current.readOnly).toBe(false) + expect(result.current.filterVar({ type: VarType.file } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.arrayFile } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.string } as never)).toBe(false) + }) +}) diff --git a/web/app/components/workflow/nodes/document-extractor/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/document-extractor/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..935118f26e --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,43 @@ +import type { DocExtractorNodeType } from '../types' +import { renderHook } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import useSingleRunFormParams from '../use-single-run-form-params' + +const createData = (overrides: Partial = {}): DocExtractorNodeType => ({ + title: 'Document Extractor', + desc: '', + type: BlockEnum.DocExtractor, + variable_selector: ['start', 'files'], + is_array_file: false, + ...overrides, +}) + +describe('document-extractor/use-single-run-form-params', () => { + it('exposes a single files form and updates run input values', () => { + const setRunInputData = vi.fn() + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'doc-node', + payload: createData(), + runInputData: { files: ['old-file'] }, + runInputDataRef: { current: {} }, + getInputVars: () => [], + setRunInputData, + toVarInputs: () => [], + })) + + expect(result.current.forms).toHaveLength(1) + expect(result.current.forms[0].inputs).toEqual([ + expect.objectContaining({ + variable: 'files', + required: true, + }), + ]) + + result.current.forms[0].onChange({ files: ['new-file'] }) + + expect(setRunInputData).toHaveBeenCalledWith({ files: ['new-file'] }) + expect(result.current.getDependentVars()).toEqual([['start', 'files']]) + expect(result.current.getDependentVar('files')).toEqual(['start', 'files']) + }) +}) diff --git a/web/app/components/workflow/nodes/end/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/end/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..b4218e338b --- /dev/null +++ b/web/app/components/workflow/nodes/end/__tests__/panel.spec.tsx @@ -0,0 +1,58 @@ +import type { EndNodeType } from '../types' +import type { PanelProps } from '@/types/workflow' +import { fireEvent, render, screen } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' + +const mockUseConfig = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +const createData = (overrides: Partial = {}): EndNodeType => ({ + title: 'End', + desc: '', + type: BlockEnum.End, + outputs: [], + ...overrides, +}) + +describe('EndPanel', () => { + const handleVarListChange = vi.fn() + const handleAddVariable = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue({ + readOnly: false, + inputs: createData(), + handleVarListChange, + handleAddVariable, + }) + }) + + it('should show the output field and allow adding output variables when writable', () => { + render() + + expect(screen.getByText('workflow.nodes.end.output.variable')).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('add-button')) + + expect(handleAddVariable).toHaveBeenCalledTimes(1) + }) + + it('should hide the add action when the node is read-only', () => { + mockUseConfig.mockReturnValue({ + readOnly: true, + inputs: createData(), + handleVarListChange, + handleAddVariable, + }) + + render() + + expect(screen.queryByTestId('add-button')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/end/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/end/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..8d0cbff547 --- /dev/null +++ b/web/app/components/workflow/nodes/end/__tests__/use-config.spec.ts @@ -0,0 +1,76 @@ +import type { EndNodeType } from '../types' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseVarList = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: () => mockUseNodesReadOnly(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseVarList(...args), +})) + +const createPayload = (overrides: Partial = {}): EndNodeType => ({ + title: 'End', + desc: '', + type: BlockEnum.End, + outputs: [], + ...overrides, +}) + +describe('end/use-config', () => { + const mockHandleVarListChange = vi.fn() + const mockHandleAddVariable = vi.fn() + const mockSetInputs = vi.fn() + let currentInputs: EndNodeType + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: true }) + mockUseNodeCrud.mockReturnValue({ + inputs: currentInputs, + setInputs: mockSetInputs, + }) + mockUseVarList.mockReturnValue({ + handleVarListChange: mockHandleVarListChange, + handleAddVariable: mockHandleAddVariable, + }) + }) + + it('should build var-list handlers against outputs and surface the readonly state', () => { + const { result } = renderHook(() => useConfig('end-node', currentInputs)) + const config = mockUseVarList.mock.calls[0][0] as { setInputs: (inputs: EndNodeType) => void } + + expect(mockUseVarList).toHaveBeenCalledWith(expect.objectContaining({ + inputs: currentInputs, + setInputs: expect.any(Function), + varKey: 'outputs', + })) + expect(result.current.readOnly).toBe(true) + expect(result.current.handleVarListChange).toBe(mockHandleVarListChange) + expect(result.current.handleAddVariable).toBe(mockHandleAddVariable) + + act(() => { + config.setInputs(createPayload({ + outputs: currentInputs.outputs, + })) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + outputs: currentInputs.outputs, + })) + }) +}) diff --git a/web/app/components/workflow/nodes/http/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/http/__tests__/node.spec.tsx new file mode 100644 index 0000000000..428aabd99e --- /dev/null +++ b/web/app/components/workflow/nodes/http/__tests__/node.spec.tsx @@ -0,0 +1,67 @@ +import type { HttpNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import Node from '../node' +import { AuthorizationType, BodyType, Method } from '../types' + +const mockReadonlyInputWithSelectVar = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/nodes/_base/components/readonly-input-with-select-var', () => ({ + __esModule: true, + default: (props: { value: string, nodeId: string, className?: string }) => { + mockReadonlyInputWithSelectVar(props) + return
{props.value}
+ }, +})) + +const createData = (overrides: Partial = {}): HttpNodeType => ({ + title: 'HTTP Request', + desc: '', + type: BlockEnum.HttpRequest, + variables: [], + method: Method.get, + url: 'https://api.example.com', + authorization: { type: AuthorizationType.none }, + headers: '', + params: '', + body: { type: BodyType.none, data: [] }, + timeout: { connect: 5, read: 10, write: 15 }, + ssl_verify: true, + ...overrides, +}) + +describe('http/node', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders the request method and forwards the URL to the readonly input', () => { + render( + , + ) + + expect(screen.getByText('post')).toBeInTheDocument() + expect(screen.getByTestId('readonly-input')).toHaveTextContent('https://api.example.com/users') + expect(mockReadonlyInputWithSelectVar).toHaveBeenCalledWith(expect.objectContaining({ + nodeId: 'http-node', + value: 'https://api.example.com/users', + })) + }) + + it('renders nothing when the request URL is empty', () => { + const { container } = render( + , + ) + + expect(container).toBeEmptyDOMElement() + }) +}) diff --git a/web/app/components/workflow/nodes/http/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/http/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..e8ce5ac5c3 --- /dev/null +++ b/web/app/components/workflow/nodes/http/__tests__/panel.spec.tsx @@ -0,0 +1,295 @@ +import type { ReactNode } from 'react' +import type { HttpNodeType } from '../types' +import type { NodePanelProps } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' +import { AuthorizationType, BodyPayloadValueType, BodyType, Method } from '../types' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockAuthorizationModal = vi.hoisted(() => vi.fn()) +const mockCurlPanel = vi.hoisted(() => vi.fn()) +const mockApiInput = vi.hoisted(() => vi.fn()) +const mockKeyValue = vi.hoisted(() => vi.fn()) +const mockEditBody = vi.hoisted(() => vi.fn()) +const mockTimeout = vi.hoisted(() => vi.fn()) + +type ApiInputProps = { + method: Method + url: string + onMethodChange: (method: Method) => void + onUrlChange: (url: string) => void +} + +type KeyValueProps = { + nodeId: string + list: Array<{ key: string, value: string }> + onChange: (value: Array<{ key: string, value: string }>) => void + onAdd: () => void +} + +type EditBodyProps = { + payload: HttpNodeType['body'] + onChange: (value: HttpNodeType['body']) => void +} + +type TimeoutProps = { + payload: HttpNodeType['timeout'] + onChange: (value: HttpNodeType['timeout']) => void +} + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../components/authorization', () => ({ + __esModule: true, + default: (props: { nodeId: string, payload: HttpNodeType['authorization'], onChange: (value: HttpNodeType['authorization']) => void, onHide: () => void }) => { + mockAuthorizationModal(props) + return
{props.nodeId}
+ }, +})) + +vi.mock('../components/curl-panel', () => ({ + __esModule: true, + default: (props: { nodeId: string, onHide: () => void, handleCurlImport: (node: HttpNodeType) => void }) => { + mockCurlPanel(props) + return
{props.nodeId}
+ }, +})) + +vi.mock('../components/api-input', () => ({ + __esModule: true, + default: (props: ApiInputProps) => { + mockApiInput(props) + return ( +
+
{`${props.method}:${props.url}`}
+ + +
+ ) + }, +})) + +vi.mock('../components/key-value', () => ({ + __esModule: true, + default: (props: KeyValueProps) => { + mockKeyValue(props) + return ( +
+
{props.list.map(item => `${item.key}:${item.value}`).join(',')}
+ + +
+ ) + }, +})) + +vi.mock('../components/edit-body', () => ({ + __esModule: true, + default: (props: EditBodyProps) => { + mockEditBody(props) + return ( + + ) + }, +})) + +vi.mock('../components/timeout', () => ({ + __esModule: true, + default: (props: TimeoutProps) => { + mockTimeout(props) + return ( + + ) + }, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/output-vars', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) =>
{children}
, + VarItem: ({ name, type }: { name: string, type: string }) =>
{`${name}:${type}`}
, +})) + +const createData = (overrides: Partial = {}): HttpNodeType => ({ + title: 'HTTP Request', + desc: '', + type: BlockEnum.HttpRequest, + variables: [], + method: Method.get, + url: 'https://api.example.com', + authorization: { type: AuthorizationType.none }, + headers: '', + params: '', + body: { type: BodyType.none, data: [] }, + timeout: { connect: 5, read: 10, write: 15 }, + ssl_verify: true, + ...overrides, +}) + +const panelProps = {} as NodePanelProps['panelProps'] + +describe('http/panel', () => { + const handleMethodChange = vi.fn() + const handleUrlChange = vi.fn() + const setHeaders = vi.fn() + const addHeader = vi.fn() + const setParams = vi.fn() + const addParam = vi.fn() + const setBody = vi.fn() + const showAuthorization = vi.fn() + const hideAuthorization = vi.fn() + const setAuthorization = vi.fn() + const setTimeout = vi.fn() + const showCurlPanel = vi.fn() + const hideCurlPanel = vi.fn() + const handleCurlImport = vi.fn() + const handleSSLVerifyChange = vi.fn() + + const createConfigResult = (overrides: Record = {}) => ({ + readOnly: false, + isDataReady: true, + inputs: createData({ + authorization: { type: AuthorizationType.apiKey, config: null }, + }), + handleMethodChange, + handleUrlChange, + headers: [{ key: 'accept', value: 'application/json' }], + setHeaders, + addHeader, + params: [{ key: 'page', value: '1' }], + setParams, + addParam, + setBody, + isShowAuthorization: false, + showAuthorization, + hideAuthorization, + setAuthorization, + setTimeout, + isShowCurlPanel: false, + showCurlPanel, + hideCurlPanel, + handleCurlImport, + handleSSLVerifyChange, + ...overrides, + }) + + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders request fields, forwards child changes, and wires header operations', async () => { + const user = userEvent.setup() + + render( + , + ) + + expect(screen.getByText('get:https://api.example.com')).toBeInTheDocument() + expect(screen.getByText('body:string')).toBeInTheDocument() + expect(screen.getByText('status_code:number')).toBeInTheDocument() + expect(screen.getByText('headers:object')).toBeInTheDocument() + expect(screen.getByText('files:Array[File]')).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'emit-method-change' })) + await user.click(screen.getByRole('button', { name: 'emit-url-change' })) + await user.click(screen.getAllByRole('button', { name: 'emit-key-value-change' })[0]!) + await user.click(screen.getAllByRole('button', { name: 'emit-key-value-add' })[0]!) + await user.click(screen.getAllByRole('button', { name: 'emit-key-value-change' })[1]!) + await user.click(screen.getAllByRole('button', { name: 'emit-key-value-add' })[1]!) + await user.click(screen.getByRole('button', { name: 'emit-body-change' })) + await user.click(screen.getByRole('button', { name: 'emit-timeout-change' })) + await user.click(screen.getByText('workflow.nodes.http.authorization.authorization')) + await user.click(screen.getByText('workflow.nodes.http.curl.title')) + await user.click(screen.getByRole('switch')) + + expect(handleMethodChange).toHaveBeenCalledWith(Method.post) + expect(handleUrlChange).toHaveBeenCalledWith('https://changed.example.com') + expect(setHeaders).toHaveBeenCalledWith([{ key: 'x-token', value: '123' }]) + expect(addHeader).toHaveBeenCalledTimes(1) + expect(setParams).toHaveBeenCalledWith([{ key: 'x-token', value: '123' }]) + expect(addParam).toHaveBeenCalledTimes(1) + expect(setBody).toHaveBeenCalledWith({ + type: BodyType.json, + data: [{ type: 'text', value: '{"hello":"world"}' }], + }) + expect(setTimeout).toHaveBeenCalledWith(expect.objectContaining({ connect: 9 })) + expect(showAuthorization).toHaveBeenCalledTimes(1) + expect(showCurlPanel).toHaveBeenCalledTimes(1) + expect(handleSSLVerifyChange).toHaveBeenCalledWith(false) + expect(mockApiInput).toHaveBeenCalledWith(expect.objectContaining({ + method: Method.get, + url: 'https://api.example.com', + })) + }) + + it('returns null before the config data is ready', () => { + mockUseConfig.mockReturnValueOnce(createConfigResult({ isDataReady: false })) + + const { container } = render( + , + ) + + expect(container).toBeEmptyDOMElement() + }) + + it('renders auth and curl panels only when writable and toggled on', () => { + mockUseConfig.mockReturnValueOnce(createConfigResult({ + isShowAuthorization: true, + isShowCurlPanel: true, + })) + + const { rerender } = render( + , + ) + + expect(screen.getByTestId('authorization-modal')).toHaveTextContent('http-node') + expect(screen.getByTestId('curl-panel')).toHaveTextContent('http-node') + + mockUseConfig.mockReturnValueOnce(createConfigResult({ + readOnly: true, + isShowAuthorization: true, + isShowCurlPanel: true, + })) + + rerender( + , + ) + + expect(screen.queryByTestId('authorization-modal')).not.toBeInTheDocument() + expect(screen.queryByTestId('curl-panel')).not.toBeInTheDocument() + expect(screen.getByRole('switch')).toHaveAttribute('aria-disabled', 'true') + }) +}) diff --git a/web/app/components/workflow/nodes/http/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/http/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..e771122e28 --- /dev/null +++ b/web/app/components/workflow/nodes/http/__tests__/use-config.spec.ts @@ -0,0 +1,271 @@ +import type { HttpNodeType } from '../types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { useNodesReadOnly } from '@/app/components/workflow/hooks' +import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' +import { useStore } from '@/app/components/workflow/store' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import useVarList from '../../_base/hooks/use-var-list' +import useKeyValueList from '../hooks/use-key-value-list' +import { APIType, AuthorizationType, BodyPayloadValueType, BodyType, Method } from '../types' +import useConfig from '../use-config' + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('../hooks/use-key-value-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: vi.fn(), +})) + +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseVarList = vi.mocked(useVarList) +const mockUseKeyValueList = vi.mocked(useKeyValueList) +const mockUseStore = vi.mocked(useStore) + +const createPayload = (overrides: Partial = {}): HttpNodeType => ({ + title: 'HTTP Request', + desc: '', + type: BlockEnum.HttpRequest, + variables: [], + method: Method.get, + url: 'https://api.example.com', + authorization: { type: AuthorizationType.none }, + headers: 'accept:application/json', + params: 'page:1', + body: { + type: BodyType.json, + data: '{"name":"alice"}', + }, + timeout: { connect: 5, read: 10, write: 15 }, + ssl_verify: true, + ...overrides, +}) + +describe('http/use-config', () => { + const mockSetInputs = vi.fn() + const mockHandleVarListChange = vi.fn() + const mockHandleAddVariable = vi.fn() + const headerSetList = vi.fn() + const headerAddItem = vi.fn() + const headerToggle = vi.fn() + const paramSetList = vi.fn() + const paramAddItem = vi.fn() + const paramToggle = vi.fn() + let currentInputs: HttpNodeType + let headerFieldChange: ((value: string) => void) | undefined + let paramFieldChange: ((value: string) => void) | undefined + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + headerFieldChange = undefined + paramFieldChange = undefined + + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs: mockSetInputs, + })) + mockUseVarList.mockReturnValue({ + handleVarListChange: mockHandleVarListChange, + handleAddVariable: mockHandleAddVariable, + } as ReturnType) + mockUseKeyValueList.mockImplementation((value, onChange) => { + if (value === currentInputs.headers) { + headerFieldChange = onChange + return { + list: [{ id: 'header-1', key: 'accept', value: 'application/json' }], + setList: headerSetList, + addItem: headerAddItem, + isKeyValueEdit: true, + toggleIsKeyValueEdit: headerToggle, + } + } + + paramFieldChange = onChange + return { + list: [{ id: 'param-1', key: 'page', value: '1' }], + setList: paramSetList, + addItem: paramAddItem, + isKeyValueEdit: false, + toggleIsKeyValueEdit: paramToggle, + } + }) + mockUseStore.mockImplementation((selector) => { + const state = { + nodesDefaultConfigs: { + [BlockEnum.HttpRequest]: createPayload({ + method: Method.post, + url: 'https://default.example.com', + headers: '', + params: '', + body: { type: BodyType.none, data: [] }, + timeout: { connect: 1, read: 2, write: 3 }, + ssl_verify: false, + }), + }, + } + + return selector(state as never) + }) + }) + + it('stays pending when the node default config is unavailable', () => { + mockUseStore.mockImplementation((selector) => { + return selector({ nodesDefaultConfigs: {} } as never) + }) + + const { result } = renderHook(() => useConfig('http-node', currentInputs)) + + expect(result.current.isDataReady).toBe(false) + expect(mockSetInputs).not.toHaveBeenCalled() + }) + + it('hydrates defaults, normalizes body payloads, and exposes var-list and key-value helpers', async () => { + const { result } = renderHook(() => useConfig('http-node', currentInputs)) + + await waitFor(() => { + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + method: Method.get, + url: 'https://api.example.com', + body: { + type: BodyType.json, + data: [{ + type: BodyPayloadValueType.text, + value: '{"name":"alice"}', + }], + }, + ssl_verify: true, + })) + }) + + expect(result.current.isDataReady).toBe(true) + expect(result.current.readOnly).toBe(false) + expect(result.current.handleVarListChange).toBe(mockHandleVarListChange) + expect(result.current.handleAddVariable).toBe(mockHandleAddVariable) + expect(result.current.headers).toEqual([{ id: 'header-1', key: 'accept', value: 'application/json' }]) + expect(result.current.setHeaders).toBe(headerSetList) + expect(result.current.addHeader).toBe(headerAddItem) + expect(result.current.isHeaderKeyValueEdit).toBe(true) + expect(result.current.toggleIsHeaderKeyValueEdit).toBe(headerToggle) + expect(result.current.params).toEqual([{ id: 'param-1', key: 'page', value: '1' }]) + expect(result.current.setParams).toBe(paramSetList) + expect(result.current.addParam).toBe(paramAddItem) + expect(result.current.isParamKeyValueEdit).toBe(false) + expect(result.current.toggleIsParamKeyValueEdit).toBe(paramToggle) + expect(result.current.filterVar({ type: VarType.string } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.number } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.secret } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.file } as never)).toBe(false) + }) + + it('initializes empty body data arrays when the payload body is missing', async () => { + currentInputs = createPayload({ + body: { + type: BodyType.formData, + data: undefined as unknown as HttpNodeType['body']['data'], + }, + }) + + renderHook(() => useConfig('http-node', currentInputs)) + + await waitFor(() => { + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + body: { + type: BodyType.formData, + data: [], + }, + })) + }) + }) + + it('updates request fields, authorization state, curl imports, and ssl verification', async () => { + const { result } = renderHook(() => useConfig('http-node', currentInputs)) + + await waitFor(() => { + expect(result.current.isDataReady).toBe(true) + }) + + mockSetInputs.mockClear() + + act(() => { + result.current.handleMethodChange(Method.delete) + result.current.handleUrlChange('https://changed.example.com') + headerFieldChange?.('x-token:123') + paramFieldChange?.('size:20') + result.current.setBody({ type: BodyType.rawText, data: 'raw payload' }) + result.current.showAuthorization() + }) + + expect(result.current.isShowAuthorization).toBe(true) + + act(() => { + result.current.hideAuthorization() + result.current.setAuthorization({ + type: AuthorizationType.apiKey, + config: { + type: APIType.bearer, + api_key: 'secret', + }, + }) + result.current.setTimeout({ connect: 30, read: 40, write: 50 }) + result.current.showCurlPanel() + }) + + expect(result.current.isShowCurlPanel).toBe(true) + + act(() => { + result.current.hideCurlPanel() + result.current.handleCurlImport(createPayload({ + method: Method.patch, + url: 'https://imported.example.com', + headers: 'authorization:Bearer imported', + params: 'debug:true', + body: { type: BodyType.json, data: [{ type: BodyPayloadValueType.text, value: '{"ok":true}' }] }, + })) + result.current.handleSSLVerifyChange(false) + }) + + expect(result.current.isShowAuthorization).toBe(false) + expect(result.current.isShowCurlPanel).toBe(false) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ method: Method.delete })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ url: 'https://changed.example.com' })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ headers: 'x-token:123' })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ params: 'size:20' })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + body: { type: BodyType.rawText, data: 'raw payload' }, + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + authorization: expect.objectContaining({ + type: AuthorizationType.apiKey, + }), + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + timeout: { connect: 30, read: 40, write: 50 }, + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + method: Method.patch, + url: 'https://imported.example.com', + headers: 'authorization:Bearer imported', + params: 'debug:true', + body: { type: BodyType.json, data: [{ type: BodyPayloadValueType.text, value: '{"ok":true}' }] }, + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ ssl_verify: false })) + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/human-input/__tests__/node.spec.tsx new file mode 100644 index 0000000000..915f9136be --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/__tests__/node.spec.tsx @@ -0,0 +1,83 @@ +import type { HumanInputNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import Node from '../node' +import { DeliveryMethodType, UserActionButtonType } from '../types' + +vi.mock('../../_base/components/node-handle', () => ({ + NodeSourceHandle: (props: { handleId: string }) =>
{`handle:${props.handleId}`}
, +})) + +const createData = (overrides: Partial = {}): HumanInputNodeType => ({ + title: 'Human Input', + desc: '', + type: BlockEnum.HumanInput, + delivery_methods: [{ + id: 'dm-webapp', + type: DeliveryMethodType.WebApp, + enabled: true, + }, { + id: 'dm-email', + type: DeliveryMethodType.Email, + enabled: true, + }], + form_content: 'Please review this request', + inputs: [{ + type: InputVarType.textInput, + output_variable_name: 'review_result', + default: { + selector: [], + type: 'constant', + value: '', + }, + }], + user_actions: [{ + id: 'approve', + title: 'Approve', + button_style: UserActionButtonType.Primary, + }, { + id: 'reject', + title: 'Reject', + button_style: UserActionButtonType.Default, + }], + timeout: 3, + timeout_unit: 'day', + ...overrides, +}) + +describe('human-input/node', () => { + it('renders delivery methods, user action handles, and the timeout handle', () => { + render( + , + ) + + expect(screen.getByText('workflow.nodes.humanInput.deliveryMethod.title')).toBeInTheDocument() + expect(screen.getByText('webapp')).toBeInTheDocument() + expect(screen.getByText('email')).toBeInTheDocument() + expect(screen.getByText('approve')).toBeInTheDocument() + expect(screen.getByText('reject')).toBeInTheDocument() + expect(screen.getByText('Timeout')).toBeInTheDocument() + expect(screen.getByText('handle:approve')).toBeInTheDocument() + expect(screen.getByText('handle:reject')).toBeInTheDocument() + expect(screen.getByText('handle:__timeout')).toBeInTheDocument() + }) + + it('keeps the timeout handle when delivery methods and actions are empty', () => { + render( + , + ) + + expect(screen.queryByText('workflow.nodes.humanInput.deliveryMethod.title')).not.toBeInTheDocument() + expect(screen.getByText('Timeout')).toBeInTheDocument() + expect(screen.getByText('handle:__timeout')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/human-input/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..937a2da61a --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/__tests__/panel.spec.tsx @@ -0,0 +1,386 @@ +import type { ReactNode } from 'react' +import type useConfig from '../hooks/use-config' +import type { HumanInputNodeType } from '../types' +import type { NodePanelProps } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import copy from 'copy-to-clipboard' +import { toast } from '@/app/components/base/ui/toast' +import { BlockEnum, InputVarType, VarType } from '@/app/components/workflow/types' +import Panel from '../panel' +import { DeliveryMethodType, UserActionButtonType } from '../types' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) +const mockUseAvailableVarList = vi.hoisted(() => vi.fn()) +const mockDeliveryMethod = vi.hoisted(() => vi.fn()) +const mockFormContent = vi.hoisted(() => vi.fn()) +const mockFormContentPreview = vi.hoisted(() => vi.fn()) +const mockTimeoutInput = vi.hoisted(() => vi.fn()) +const mockUserActionItem = vi.hoisted(() => vi.fn()) + +vi.mock('copy-to-clipboard', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: vi.fn(), + }, +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + __esModule: true, + default: () =>
tooltip
, +})) + +vi.mock('@/app/components/base/action-button', () => ({ + __esModule: true, + default: (props: { + children: ReactNode + onClick: () => void + }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { nodePanelWidth: number }) => unknown) => mockUseStore(selector), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseAvailableVarList(...args), +})) + +vi.mock('../hooks/use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../components/delivery-method', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + onChange: (methods: HumanInputNodeType['delivery_methods']) => void + }) => { + mockDeliveryMethod(props) + return ( + + ) + }, +})) + +vi.mock('../components/form-content', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + isExpand: boolean + onChange: (value: string) => void + onFormInputsChange: (value: HumanInputNodeType['inputs']) => void + onFormInputItemRename: (oldName: string, newName: string) => void + onFormInputItemRemove: (name: string) => void + }) => { + mockFormContent(props) + return ( +
+
{props.readonly ? 'form-content:readonly' : `form-content:${props.isExpand ? 'expanded' : 'collapsed'}`}
+ + + + +
+ ) + }, +})) + +vi.mock('../components/form-content-preview', () => ({ + __esModule: true, + default: (props: { + onClose: () => void + }) => { + mockFormContentPreview(props) + return ( +
+
form-preview
+ +
+ ) + }, +})) + +vi.mock('../components/timeout', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + onChange: (value: { timeout: number, unit: 'hour' | 'day' }) => void + }) => { + mockTimeoutInput(props) + return ( + + ) + }, +})) + +vi.mock('../components/user-action', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + data: HumanInputNodeType['user_actions'][number] + onChange: (value: HumanInputNodeType['user_actions'][number]) => void + onDelete: (id: string) => void + }) => { + mockUserActionItem(props) + return ( +
+
{`${props.data.id}:${props.readonly ? 'readonly' : 'editable'}`}
+ + +
+ ) + }, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/output-vars', () => ({ + __esModule: true, + default: (props: { + children: ReactNode + collapsed?: boolean + onCollapse?: (collapsed: boolean) => void + }) => ( +
+ + {props.children} +
+ ), + VarItem: ({ name, type, description }: { name: string, type: string, description: string }) => ( +
{`${name}:${type}:${description}`}
+ ), +})) + +vi.mock('@remixicon/react', () => ({ + RiAddLine: () => add-icon, + RiClipboardLine: () => clipboard-icon, + RiCollapseDiagonalLine: () => collapse-icon, + RiExpandDiagonalLine: () => expand-icon, + RiEyeLine: () => preview-icon, +})) + +const mockCopy = vi.mocked(copy) +const mockToastSuccess = vi.mocked(toast.success) + +const createData = (overrides: Partial = {}): HumanInputNodeType => ({ + title: 'Human Input', + desc: '', + type: BlockEnum.HumanInput, + delivery_methods: [{ + id: 'dm-webapp', + type: DeliveryMethodType.WebApp, + enabled: true, + }], + form_content: 'Please review this request', + inputs: [{ + type: InputVarType.textInput, + output_variable_name: 'review_result', + default: { + selector: [], + type: 'constant', + value: '', + }, + }], + user_actions: [{ + id: 'approve', + title: 'Approve', + button_style: UserActionButtonType.Primary, + }], + timeout: 3, + timeout_unit: 'day', + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + handleDeliveryMethodChange: vi.fn(), + handleUserActionAdd: vi.fn(), + handleUserActionChange: vi.fn(), + handleUserActionDelete: vi.fn(), + handleTimeoutChange: vi.fn(), + handleFormContentChange: vi.fn(), + handleFormInputsChange: vi.fn(), + handleFormInputItemRename: vi.fn(), + handleFormInputItemRemove: vi.fn(), + editorKey: 1, + structuredOutputCollapsed: true, + setStructuredOutputCollapsed: vi.fn(), + ...overrides, +}) + +const renderPanel = (data: HumanInputNodeType = createData()) => { + const props: NodePanelProps = { + id: 'human-input-node', + data, + panelProps: { + getInputVars: vi.fn(() => []), + toVarInputs: vi.fn(() => []), + runInputData: {}, + runInputDataRef: { current: {} }, + setRunInputData: vi.fn(), + runResult: null, + }, + } + + return render() +} + +describe('human-input/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseStore.mockImplementation(selector => selector({ nodePanelWidth: 480 })) + mockUseAvailableVarList.mockImplementation((_id, options?: { filterVar?: (payload: { type: VarType }) => boolean }) => ({ + availableVars: [{ + variable: ['start', 'email'], + type: VarType.string, + }, { + variable: ['start', 'files'], + type: VarType.file, + }].filter(variable => options?.filterVar ? options.filterVar({ type: variable.type } as never) : true), + availableNodesWithParent: [{ + id: 'start-node', + data: { + title: 'Start', + type: BlockEnum.Start, + }, + }], + })) + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders editable controls, forwards updates, and toggles preview and output sections', async () => { + const user = userEvent.setup() + const config = createConfigResult() + mockUseConfig.mockReturnValue(config) + + const { container } = renderPanel() + + expect(screen.getByRole('button', { name: 'delivery-method:editable' })).toBeInTheDocument() + expect(screen.getByText('form-content:collapsed')).toBeInTheDocument() + expect(screen.getByText('approve:editable')).toBeInTheDocument() + expect(screen.getByText('review_result:string:Form input value')).toBeInTheDocument() + expect(screen.getByText('__action_id:string:Action ID user triggered')).toBeInTheDocument() + expect(screen.getByText('__rendered_content:string:Rendered content')).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'delivery-method:editable' })) + await user.click(screen.getByRole('button', { name: /workflow\.nodes\.humanInput\.formContent\.preview/ })) + await user.click(screen.getByRole('button', { name: 'change-form-content' })) + await user.click(screen.getByRole('button', { name: 'change-form-inputs' })) + await user.click(screen.getByRole('button', { name: 'rename-form-input' })) + await user.click(screen.getByRole('button', { name: 'remove-form-input' })) + await user.click(screen.getByRole('button', { name: 'action-button' })) + await user.click(screen.getByRole('button', { name: 'change-action-approve' })) + await user.click(screen.getByRole('button', { name: 'delete-action-approve' })) + await user.click(screen.getByRole('button', { name: 'timeout:editable' })) + await user.click(screen.getByRole('button', { name: 'toggle-output-vars' })) + await user.click(screen.getByRole('button', { name: 'close-preview' })) + + const iconContainers = container.querySelectorAll('div.flex.size-6.cursor-pointer') + await user.click(iconContainers[0] as HTMLElement) + await user.click(iconContainers[1] as HTMLElement) + + expect(config.handleDeliveryMethodChange).toHaveBeenCalledWith([{ + id: 'dm-email', + type: DeliveryMethodType.Email, + enabled: true, + }]) + expect(config.handleFormContentChange).toHaveBeenCalledWith('Updated content') + expect(config.handleFormInputsChange).toHaveBeenCalled() + expect(config.handleFormInputItemRename).toHaveBeenCalledWith('name', 'email') + expect(config.handleFormInputItemRemove).toHaveBeenCalledWith('name') + expect(config.handleUserActionAdd).toHaveBeenCalledWith({ + id: 'action_2', + title: 'Button Text 2', + button_style: UserActionButtonType.Default, + }) + expect(config.handleUserActionChange).toHaveBeenCalledWith(0, { + id: 'approve', + title: 'Approve updated', + button_style: UserActionButtonType.Primary, + }) + expect(config.handleUserActionDelete).toHaveBeenCalledWith('approve') + expect(config.handleTimeoutChange).toHaveBeenCalledWith({ timeout: 8, unit: 'hour' }) + expect(config.setStructuredOutputCollapsed).toHaveBeenCalledWith(false) + expect(mockCopy).toHaveBeenCalledWith('Please review this request') + expect(mockToastSuccess).toHaveBeenCalledWith('common.actionMsg.copySuccessfully') + expect(mockFormContentPreview).toHaveBeenCalled() + }) + + it('renders readonly and empty states without preview or add controls', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + readOnly: true, + inputs: createData({ + user_actions: [], + }), + structuredOutputCollapsed: false, + })) + + renderPanel() + + expect(screen.getByRole('button', { name: 'delivery-method:readonly' })).toBeInTheDocument() + expect(screen.getByText('form-content:readonly')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.humanInput.userActions.emptyTip')).toBeInTheDocument() + expect(screen.queryByRole('button', { name: /workflow\.nodes\.humanInput\.formContent\.preview/ })).not.toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'action-button' })).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'timeout:readonly' })).toBeInTheDocument() + expect(screen.queryByText('form-preview')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/email-configure-modal.spec.tsx b/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/email-configure-modal.spec.tsx new file mode 100644 index 0000000000..cec9ffe69a --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/email-configure-modal.spec.tsx @@ -0,0 +1,180 @@ +import type { EmailConfig } from '../../../types' +import { fireEvent, render, screen } from '@testing-library/react' +import EmailConfigureModal from '../email-configure-modal' + +const mockToastError = vi.hoisted(() => vi.fn()) +const mockUseAppContextSelector = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: (message: string) => mockToastError(message), + }, +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { userProfile: { email: string } }) => string) => + mockUseAppContextSelector(selector), +})) + +vi.mock('../mail-body-input', () => ({ + default: ({ value, onChange }: { value: string, onChange: (value: string) => void }) => ( +