diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c32fc9d0cb..29f5b090f8 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -77,6 +77,8 @@ jobs: with: files: | web/** + e2e/** + sdks/nodejs-client/** packages/** package.json pnpm-lock.yaml @@ -95,14 +97,14 @@ jobs: id: eslint-cache-restore uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 with: - path: web/.eslintcache - key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + path: .eslintcache + key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} restore-keys: | - ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . run: vp run lint:ci - name: Web tsslint @@ -112,7 +114,7 @@ jobs: - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . run: vp run type-check - name: Web dead code check @@ -124,7 +126,7 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 with: - path: web/.eslintcache + path: .eslintcache key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} superlinter: diff --git a/.gitignore b/.gitignore index 53dea88899..3493a7c756 100644 --- a/.gitignore +++ b/.gitignore @@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info .vscode/* !.vscode/launch.json.template +!.vscode/settings.example.json !.vscode/README.md api/.vscode # vscode Code History Extension @@ -242,3 +243,5 @@ scripts/stress-test/reports/ # Code Agent Folder .qoder/* + +.eslintcache diff --git a/.vite-hooks/pre-commit b/.vite-hooks/pre-commit index cced022568..d48381bce2 100755 --- a/.vite-hooks/pre-commit +++ b/.vite-hooks/pre-commit @@ -56,44 +56,9 @@ if $api_modified; then fi fi -if $web_modified; then - if $skip_web_checks; then - echo "Git operation in progress, skipping web checks" - exit 0 - fi - - echo "Running ESLint on web module" - - if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then - web_ts_modified=false - else - ts_diff_status=$? - if [ $ts_diff_status -eq 1 ]; then - web_ts_modified=true - else - echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)." - exit $ts_diff_status - fi - fi - - cd ./web || exit 1 - pnpm exec vp staged - - if $web_ts_modified; then - echo "Running TypeScript type-check:tsgo" - if ! npm run type-check:tsgo; then - echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors." - exit 1 - fi - else - echo "No staged TypeScript changes detected, skipping type-check:tsgo" - fi - - echo "Running knip" - if ! npm run knip; then - echo "Knip check failed. Please run 'npm run knip' to fix the errors." - exit 1 - fi - - cd ../ +if $skip_web_checks; then + echo "Git operation in progress, skipping web checks" + exit 0 fi + +vp staged diff --git a/api/commands/account.py b/api/commands/account.py index 6a2a2e0428..761323a73d 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,6 +2,7 @@ import base64 import secrets import click +from sqlalchemy.orm import Session from constants.languages import languages from extensions.ext_database import db @@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm): # encrypt password with salt password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = db.session.merge(account) - account.password = base64_password_hashed - account.password_salt = base64_salt - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + session.commit() AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm): click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account = db.session.merge(account) - account.email = normalized_new_email - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.email = normalized_new_email + session.commit() click.echo(click.style("Email updated successfully.", fg="green")) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 23351beed9..7302a4edf5 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -126,8 +126,6 @@ from .snippets import snippet_workflow, snippet_workflow_draft_variable from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport] # Import snippet controllers -from .snippets import snippet_workflow, snippet_workflow_draft_variable - # Import tag controllers from .tag import tags @@ -215,12 +213,12 @@ __all__ = [ "setup", "site", "snippet_workflow", - "snippet_workflow_draft_variable", - "snippets", - "socketio_workflow", "snippet_workflow", "snippet_workflow_draft_variable", + "snippet_workflow_draft_variable", "snippets", + "snippets", + "socketio_workflow", "spec", "statistic", "tags", diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 214545bac8..2ac4aef311 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -5,7 +5,6 @@ from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.enums import WorkflowExecutionStatus from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -30,6 +29,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES from extensions.ext_database import db from fields.base import ResponseModel +from graphon.enums import WorkflowExecutionStatus from libs.helper import build_icon_url from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 78ddb904e1..91fbe4a85a 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -2,7 +2,6 @@ import logging from flask import request from flask_restx import Resource, fields -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -23,6 +22,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index d83925d173..fe274e4c9a 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -3,7 +3,6 @@ from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -27,6 +26,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 7101d5df7b..c720a5e074 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from controllers.console import console_ns @@ -20,6 +19,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 5b1abc98dc..d517f695b8 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -18,12 +18,6 @@ from models.enums import AppMCPServerStatus from models.model import AppMCPServer -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class MCPServerCreatePayload(BaseModel): description: str | None = Field(default=None, description="Server description") parameters: dict[str, Any] = Field(..., description="Server parameters configuration") @@ -36,19 +30,25 @@ class MCPServerUpdatePayload(BaseModel): status: str | None = Field(default=None, description="Server status") +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + class AppMCPServerResponse(ResponseModel): id: str name: str server_code: str description: str - status: str + status: AppMCPServerStatus parameters: dict[str, Any] | list[Any] | str created_at: int | None = None updated_at: int | None = None @field_validator("parameters", mode="before") @classmethod - def _parse_json_string(cls, value: Any) -> Any: + def _normalize_parameters(cls, value: Any) -> Any: if isinstance(value, str): try: return json.loads(value) @@ -70,7 +70,9 @@ class AppMCPServerController(Resource): @console_ns.doc("get_app_mcp_server") @console_ns.doc(description="Get MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @login_required @account_initialization_required @setup_required @@ -85,7 +87,9 @@ class AppMCPServerController(Resource): @console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__]) - @console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @account_initialization_required @get_app_model @@ -111,13 +115,15 @@ class AppMCPServerController(Resource): ) db.session.add(server) db.session.commit() - return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201 @console_ns.doc("update_app_mcp_server") @console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__]) - @console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @get_app_model @@ -154,7 +160,7 @@ class AppMCPServerRefreshController(Resource): @console_ns.doc("refresh_app_mcp_server") @console_ns.doc(description="Refresh MCP server configuration and regenerate server code") @console_ns.doc(params={"server_id": "Server ID"}) - @console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @setup_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index ec29970ba6..8571afde31 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -5,11 +5,6 @@ from typing import Any, Literal from flask import abort, request from flask_restx import Resource, fields, marshal, marshal_with -from graphon.enums import NodeType -from graphon.file import File -from graphon.file import helpers as file_helpers -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, ValidationError, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -42,6 +37,11 @@ from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.online_user_fields import online_user_list_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from graphon.enums import NodeType +from graphon.file import File +from graphon.file import helpers as file_helpers +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 6b402898e8..4b39590235 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -4,7 +4,6 @@ from typing import Any from dateutil.parser import isoparse from flask import request from flask_restx import Resource -from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker @@ -16,6 +15,7 @@ from extensions.ext_database import db from fields.base import ResponseModel from fields.end_user_fields import SimpleEndUser from fields.member_fields import SimpleAccount +from graphon.enums import WorkflowExecutionStatus from libs.login import login_required from models import App from models.model import AppMode diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index a1a075be71..6748d95d6b 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -3,8 +3,6 @@ from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -28,6 +26,8 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index b55cda4244..727428c8e7 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -5,11 +5,11 @@ from typing import Concatenate from flask import jsonify, request from flask.typing import ResponseReturnValue from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 14ca27acbd..0b493d2c71 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -4,7 +4,6 @@ from urllib.parse import quote from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -54,6 +53,7 @@ from fields.dataset_fields import ( weighted_score_fields, ) from fields.document_fields import document_status_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 98d4ad9412..3372a967d9 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -3,20 +3,19 @@ import logging from argparse import ArgumentTypeError from collections.abc import Sequence from contextlib import ExitStack +from datetime import datetime from typing import Any, Literal, cast import sqlalchemy as sa from flask import request, send_file -from flask_restx import Resource, fields, marshal, marshal_with -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from pydantic import BaseModel, Field +from flask_restx import Resource, marshal +from pydantic import BaseModel, Field, field_validator from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload -from controllers.common.schema import get_or_create_model, register_schema_models +from controllers.common.schema import register_schema_models from controllers.console import console_ns from core.errors.error import ( LLMBadRequestError, @@ -31,14 +30,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db -from fields.dataset_fields import dataset_fields +from fields.base import ResponseModel from fields.document_fields import ( - dataset_and_document_fields, document_fields, - document_metadata_fields, document_status_fields, document_with_segments_fields, ) +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile @@ -72,27 +71,100 @@ from ..wraps import ( logger = logging.getLogger(__name__) -# Register models for flask_restx to avoid dict type issues in Swagger -dataset_model = get_or_create_model("Dataset", dataset_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields) -document_fields_copy = document_fields.copy() -document_fields_copy["doc_metadata"] = fields.List( - fields.Nested(document_metadata_model), attribute="doc_metadata_details" -) -document_model = get_or_create_model("Document", document_fields_copy) +def _normalize_enum(value: Any) -> Any: + if isinstance(value, str) or value is None: + return value + return getattr(value, "value", value) -document_with_segments_fields_copy = document_with_segments_fields.copy() -document_with_segments_fields_copy["doc_metadata"] = fields.List( - fields.Nested(document_metadata_model), attribute="doc_metadata_details" -) -document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy) -dataset_and_document_fields_copy = dataset_and_document_fields.copy() -dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model) -dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model)) -dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) +class DatasetResponse(ResponseModel): + id: str + name: str + description: str | None = None + permission: str | None = None + data_source_type: str | None = None + indexing_technique: str | None = None + created_by: str | None = None + created_at: int | None = None + + @field_validator("data_source_type", "indexing_technique", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DocumentMetadataResponse(ResponseModel): + id: str + name: str + type: str + value: str | None = None + + +class DocumentResponse(ResponseModel): + id: str + position: int | None = None + data_source_type: str | None = None + data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict") + data_source_detail_dict: Any = None + dataset_process_rule_id: str | None = None + name: str + created_from: str | None = None + created_by: str | None = None + created_at: int | None = None + tokens: int | None = None + indexing_status: str | None = None + error: str | None = None + enabled: bool | None = None + disabled_at: int | None = None + disabled_by: str | None = None + archived: bool | None = None + display_status: str | None = None + word_count: int | None = None + hit_count: int | None = None + doc_form: str | None = None + doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details") + summary_index_status: str | None = None + need_summary: bool | None = None + + @field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + @field_validator("doc_metadata", mode="before") + @classmethod + def _normalize_doc_metadata(cls, value: Any) -> list[Any]: + if value is None: + return [] + return value + + @field_validator("created_at", "disabled_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DocumentWithSegmentsResponse(DocumentResponse): + process_rule_dict: Any = None + completed_segments: int | None = None + total_segments: int | None = None + + +class DatasetAndDocumentResponse(ResponseModel): + dataset: DatasetResponse + documents: list[DocumentResponse] + batch: str class DocumentRetryPayload(BaseModel): @@ -107,6 +179,11 @@ class GenerateSummaryPayload(BaseModel): document_list: list[str] +class DocumentMetadataUpdatePayload(BaseModel): + doc_type: str | None = None + doc_metadata: Any = None + + class DocumentDatasetListParam(BaseModel): page: int = Field(1, title="Page", description="Page number.") limit: int = Field(20, title="Limit", description="Page size.") @@ -124,7 +201,13 @@ register_schema_models( DocumentRetryPayload, DocumentRenamePayload, GenerateSummaryPayload, + DocumentMetadataUpdatePayload, DocumentBatchDownloadZipPayload, + DatasetResponse, + DocumentMetadataResponse, + DocumentResponse, + DocumentWithSegmentsResponse, + DatasetAndDocumentResponse, ) @@ -357,10 +440,10 @@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) + @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() dataset_id = str(dataset_id) @@ -398,7 +481,9 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return {"dataset": dataset, "documents": documents, "batch": batch} + return DatasetAndDocumentResponse.model_validate( + {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True + ).model_dump(mode="json") @setup_required @login_required @@ -426,12 +511,13 @@ class DatasetInitApi(Resource): @console_ns.doc("init_dataset") @console_ns.doc(description="Initialize dataset with documents") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) - @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) + @console_ns.response( + 201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__] + ) @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def post(self): @@ -479,9 +565,9 @@ class DatasetInitApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = {"dataset": dataset, "documents": documents, "batch": batch} - - return response + return DatasetAndDocumentResponse.model_validate( + {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/datasets//documents//indexing-estimate") @@ -988,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource): @console_ns.doc("update_document_metadata") @console_ns.doc(description="Update document metadata") @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) - @console_ns.expect( - console_ns.model( - "UpdateDocumentMetadataRequest", - { - "doc_type": fields.String(description="Document type"), - "doc_metadata": fields.Raw(description="Document metadata"), - }, - ) - ) + @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__]) @console_ns.response(200, "Document metadata updated successfully") @console_ns.response(404, "Document not found") @console_ns.response(403, "Permission denied") @@ -1009,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - req_data = request.get_json() + req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {}) - doc_type = req_data.get("doc_type") - doc_metadata = req_data.get("doc_metadata") + doc_type = req_data.doc_type + doc_metadata = req_data.doc_metadata # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: @@ -1194,7 +1272,7 @@ class DocumentRenameApi(DocumentResource): @setup_required @login_required @account_initialization_required - @marshal_with(document_model) + @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__]) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator @@ -1212,7 +1290,7 @@ class DocumentRenameApi(DocumentResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return document + return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json") @console_ns.route("/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 354c299bef..2647bb1f5a 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -2,7 +2,6 @@ import uuid from flask import request from flask_restx import Resource, marshal -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import String, cast, func, or_, select from sqlalchemy.dialects.postgresql import JSONB @@ -32,6 +31,7 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 8fb3699849..699fa599c8 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -2,7 +2,6 @@ import logging from typing import Any from flask_restx import marshal -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -21,6 +20,7 @@ from core.errors.error import ( QuotaExceededError, ) from fields.hit_testing_fields import hit_testing_record_fields +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index bdf83b991e..fd0a8b33bc 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -2,8 +2,6 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound @@ -12,6 +10,8 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.plugin.impl.oauth import OAuthHandler +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 3549f9542d..b31d73f27d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -4,7 +4,6 @@ from typing import Any, NoReturn from flask import Response, request from flask_restx import Resource, marshal, marshal_with -from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden @@ -28,6 +27,7 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTE from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type +from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index a8077d9eb0..ee146e8287 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,7 +4,6 @@ from typing import Any, Literal, cast from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -41,6 +40,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from factories import variable_factory +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index a37077af42..ab660d9dc3 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,7 +1,6 @@ import logging from flask import request -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import services @@ -20,6 +19,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index eacd7332fe..ccdccceaa6 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -2,7 +2,6 @@ import logging from typing import Any, Literal from uuid import UUID -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -26,6 +25,7 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 64d55d7ca3..209667d1d0 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -2,7 +2,6 @@ import logging from typing import Literal from flask import request -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -25,6 +24,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.enums import FeedbackRating diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 0a3595454a..1456301a24 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -3,8 +3,6 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -61,6 +59,8 @@ from fields.workflow_fields import ( workflow_fields, workflow_partial_fields, ) +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index da88de6776..438cce4fd8 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,7 +1,5 @@ import logging -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError from controllers.common.controller_schemas import WorkflowRunPayload @@ -23,6 +21,8 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 551c86fd82..2a46d2250a 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,7 +2,6 @@ import urllib.parse import httpx from flask_restx import Resource -from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -16,6 +15,7 @@ from controllers.console import console_ns from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 3fdcbc4710..764f488755 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index b6b9deb1f9..f45b72f390 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -2,13 +2,13 @@ from typing import Any from flask import request from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index e4cfca9fa4..2a6f37aec8 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,12 +1,12 @@ from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index cbb9677309..4b10561fdb 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -3,13 +3,13 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index f8f95304f0..b2d07ff8f9 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -3,14 +3,14 @@ from typing import Any, cast from flask import request from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index aa674a63b3..b3e344ccea 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -4,7 +4,6 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -15,6 +14,7 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c9956501e2..471594f349 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -5,7 +5,6 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden @@ -28,6 +27,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 7a28a09861..d11b66244f 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -3,7 +3,6 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, model_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden @@ -16,6 +15,7 @@ from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_user, login_required from models.account import Account from models.provider_ids import TriggerProviderID diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 83c8fa02fe..72cab3de73 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,5 +1,4 @@ from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns @@ -30,6 +29,7 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.signature import get_signed_file_url_for_plugin +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index a5846e2815..2f309262cb 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel): def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ - Get current user + Get current user. NOTE: user_id is not trusted, it could be maliciously set to any value. - As a result, it could only be considered as an end user id. + As a result, it could only be considered as an end user id. Even when a + concrete end-user ID is supplied, lookups must stay tenant-scoped so one + tenant cannot bind another tenant's user record into the plugin request + context. """ if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID @@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: .limit(1) ) else: - user_model = session.get(EndUser, user_id) + user_model = session.scalar( + select(EndUser) + .where( + EndUser.id == user_id, + EndUser.tenant_id == tenant_id, + ) + .limit(1) + ) if not user_model: user_model = EndUser( diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 8066f198bb..f652bbc581 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,7 +2,6 @@ from typing import Any, Union from flask import Response from flask_restx import Resource -from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -12,6 +11,7 @@ from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 907dd1b06d..e818573b8f 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,7 +2,6 @@ import logging from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import services @@ -22,6 +21,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 3142e5118e..31f2797d66 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,7 +4,6 @@ from uuid import UUID from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -29,6 +28,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 50851aea08..c4353ca7b8 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -3,7 +3,6 @@ from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, TypeAdapter, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, NotFound @@ -22,6 +21,7 @@ from fields.conversation_fields import ( ConversationInfiniteScrollPagination, SimpleConversation, ) +from graphon.variables.types import SegmentType from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index fd954be6b1..76519cad0a 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,7 +2,6 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -19,6 +18,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 971b63577c..5992fa7410 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -2,7 +2,6 @@ from typing import Any from flask import request from flask_restx import marshal -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select from werkzeug.exceptions import NotFound @@ -23,6 +22,7 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index c0a6cb0a76..5ac65fc4e6 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,9 +1,9 @@ from flask_login import current_user from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token +from graphon.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 0ef4471018..3ad595f1f4 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -2,7 +2,6 @@ import logging from flask import request from flask_restx import fields, marshal_with -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import field_validator from werkzeug.exceptions import InternalServerError @@ -22,6 +21,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index e37f9af5f0..0528184d79 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,7 +1,6 @@ import logging from typing import Any, Literal -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -26,6 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 39afdd843f..07ecf8035b 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,7 +2,6 @@ import logging from typing import Literal from flask import request -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.enums import FeedbackRating from models.model import AppMode diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 38aeccc642..fe31e9d4ac 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, HttpUrl import services @@ -14,6 +13,7 @@ from controllers.common.errors import ( from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from services.file_service import FileService from ..common.schema import register_schema_models diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 796e090976..98211193a0 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,7 +1,5 @@ import logging -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError from controllers.common.controller_schemas import WorkflowRunPayload @@ -24,6 +22,8 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 06c746990d..790602ef5d 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -4,20 +4,6 @@ import uuid from decimal import Decimal from typing import Union, cast -from graphon.file import file_manager -from graphon.model_runtime.entities import ( - AssistantPromptMessage, - LLMUsage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import func, select from core.agent.entities import AgentEntity, AgentToolEntity @@ -43,6 +29,20 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index f07ac64498..0bc93ad34d 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -4,15 +4,6 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence from typing import Any, TypedDict -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) - from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError @@ -24,6 +15,14 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 2b2e26987e..a2186be100 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,5 +1,6 @@ import json +from core.agent.cot_agent_runner import CotAgentRunner from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -11,8 +12,6 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.utils.encoders import jsonable_encoder -from core.agent.cot_agent_runner import CotAgentRunner - class CotChatAgentRunner(CotAgentRunner): def _organize_system_prompt(self) -> SystemPromptMessage: diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index d4c52a8eb1..51a30998ae 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,5 +1,6 @@ import json +from core.agent.cot_agent_runner import CotAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -8,8 +9,6 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.utils.encoders import jsonable_encoder -from core.agent.cot_agent_runner import CotAgentRunner - class CotCompletionAgentRunner(CotAgentRunner): def _organize_instruction_prompt(self) -> str: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index fdffde85d0..29de0b8b1c 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,6 +4,13 @@ from collections.abc import Generator from copy import deepcopy from typing import Any, Union +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -19,14 +26,6 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes - -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.errors import AgentMaxIterationError -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine from models.model import Message logger = logging.getLogger(__name__) @@ -300,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): # update prompt tool for prompt_tool in prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) + tool_instance = tool_instances.get(prompt_tool.name) + if tool_instance: + self.update_prompt_message_tool(tool_instance, prompt_tool) iteration_step += 1 diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 8cccd2be6d..f341ca5a0b 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -3,9 +3,8 @@ import re from collections.abc import Generator from typing import Any, Union -from graphon.model_runtime.entities.llm_entities import LLMResultChunk - from core.agent.entities import AgentScratchpadUnit +from graphon.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index b7dd55632e..dbd7527fc6 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,14 +1,13 @@ from typing import cast -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 9d980e5ca3..02498c23e1 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,10 +1,9 @@ from collections.abc import Mapping from typing import Any -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType - from core.app.app_config.entities import ModelConfigEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 57c6d1c496..4c07445df3 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,7 +1,5 @@ from typing import Any -from graphon.model_runtime.entities.message_entities import PromptMessageRole - from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -9,6 +7,7 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 819aca864c..53563dc5da 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,14 +1,14 @@ from enum import StrEnum, auto from typing import Any, Literal -from graphon.file import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.entities import MetadataFilteringCondition +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 959c3868b4..8f20ef2ff9 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,9 +1,8 @@ from collections.abc import Mapping from typing import Any -from graphon.file import FileUploadConfig - from constants import DEFAULT_FILE_NUMBER_LIMITS +from graphon.file import FileUploadConfig class FileUploadConfigManager: diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 62e0c31d1a..13ace32fd6 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,8 +1,7 @@ import re -from graphon.variables.input_entities import VariableEntity - from core.app.app_config.entities import RagPipelineVariableEntity +from graphon.variables.input_entities import VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 985ded0f74..9e64b471cb 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -18,11 +18,6 @@ from constants import UUID_NIL if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader - from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner @@ -48,6 +43,10 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 7b4cb98bd4..4e57b4dedc 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,12 +3,6 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels import RedisChannel -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import Variable from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -43,6 +37,12 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from models import Workflow from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 5872f6b264..5cdc477028 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, In from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index a20d3f3c38..09ddce327e 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,9 +1,6 @@ import logging from typing import cast -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -19,6 +16,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 406d07927e..d5edfaeb25 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -3,11 +3,10 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from typing import Any, Union -from graphon.model_runtime.errors.invoke import InvokeError - from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 20bf81aeec..d1771452c5 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -7,7 +7,6 @@ from enum import IntEnum, auto from typing import Any from cachetools import TTLCache, cachedmethod -from graphon.runtime import GraphRuntimeState from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta @@ -22,6 +21,7 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from extensions.ext_redis import redis_client +from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 4aebc0cb30..1251b397e2 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,17 +5,6 @@ from collections.abc import Generator, Mapping, Sequence from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError - from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( @@ -41,6 +30,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 891dcece73..58afefe296 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, copy_current_request_context, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeF from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account from models.model import App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 050f763e95..077c5239f3 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,8 +1,6 @@ import logging from typing import cast -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -18,6 +16,8 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index ab277857fe..2a90fbdad0 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,9 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from graphon.runtime import GraphRuntimeState - from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.runtime import GraphRuntimeState if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index a515531616..bd685d5189 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -6,19 +6,6 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, TypedDict, Union -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import FILE_MODEL_IDENTITY, File -from graphon.runtime import GraphRuntimeState -from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment -from graphon.variables.variables import Variable -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from sqlalchemy.orm import Session @@ -68,6 +55,19 @@ from core.workflow.human_input_forms import load_form_tokens_by_form_id from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.human_input import HumanInputForm diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 61339b316a..423bfdac51 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, copy_current_request_context, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from sqlalchemy import select @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, I from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account, App, EndUser, Message from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index b216f7cf7b..6bb1ecdcb1 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,8 +1,6 @@ import logging from typing import cast -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -16,6 +14,8 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 83c74b86e5..4b2f17189b 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -10,8 +10,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, cast, overload from flask import Flask, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -43,6 +41,8 @@ from core.repositories.factory import ( WorkflowNodeExecutionRepository, ) from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 36daaf09e9..2ee0ae27eb 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -2,12 +2,6 @@ import logging import time from typing import cast -from graphon.enums import WorkflowType -from graphon.graph import Graph -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -26,6 +20,12 @@ from core.workflow.system_variables import build_bootstrap_variables, build_syst from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from models.dataset import Document, Pipeline from models.model import EndUser from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index ba070ffa94..5a1e7e117f 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,10 +8,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, overload from flask import Flask, current_app -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -38,6 +34,10 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models.account import Account from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 2cb8088971..cfb9208486 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -3,12 +3,6 @@ import time from collections.abc import Sequence from typing import cast -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels import RedisChannel -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader - from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -21,6 +15,11 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader from libs.datetime_utils import naive_utc_now from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 96387133b1..15645add57 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,9 +4,6 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Union -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -61,6 +58,9 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from models import Account from models.enums import CreatorUserRole from models.model import EndUser diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 437432611d..047b54c86c 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,39 +3,6 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph import Graph -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from pydantic import ValidationError from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -82,6 +49,39 @@ from core.workflow.system_variables import ( from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunIterationFailedEvent, + NodeRunIterationNextEvent, + NodeRunIterationStartedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunLoopNextEvent, + NodeRunLoopStartedEvent, + NodeRunLoopSucceededEvent, + NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index a3fb7b4c5d..09992f4bbf 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -2,13 +2,13 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import TYPE_CHECKING, Any -from graphon.file import File, FileUploadConfig -from graphon.model_runtime.entities.model_entities import AIModelEntity from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 482f995d8e..221b7fb058 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,14 +3,14 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import PauseReason -from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities import RetrievalSourceMetadata +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReason +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 88faf235d1..6e4ca69cf0 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,14 +2,14 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities import RetrievalSourceMetadata +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index d2d2fea4fb..d59f5125e3 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,9 +1,8 @@ import logging -from graphon.model_runtime.entities.message_entities import PromptMessage - from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation +from graphon.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index c027f42788..9811f9f830 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,14 +1,14 @@ from dataclasses import dataclass from typing import Annotated, Literal, Self -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index 8c8daf8712..bb9fc1b6fa 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -3,10 +3,10 @@ import uuid from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore + from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent - from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index 77c7bec67e..b60fe82ffe 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -2,12 +2,12 @@ import logging from datetime import UTC, datetime from typing import Any, ClassVar -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from pydantic import TypeAdapter from core.db.session_factory import session_factory from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index 278d0cb30b..c49c4eb0ac 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,16 +2,15 @@ from __future__ import annotations from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.nodes.llm.entities import ModelConfig -from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from graphon.nodes.llm.protocols import CredentialsProvider - from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 0bb10190c4..b6039e1e4e 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,4 +1,3 @@ -from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update from sqlalchemy.orm import sessionmaker @@ -8,6 +7,7 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMUsage from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 10b9c36d3e..9e688589db 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,7 +1,6 @@ import logging import time -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,6 +17,7 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index c577ce0754..4a7918032e 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -7,17 +7,16 @@ This layer centralizes model-quota deduction outside node implementations. import logging from typing import TYPE_CHECKING, cast, final, override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm import deduct_llm_quota, ensure_llm_quota_available +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent from graphon.nodes.base.node import Node -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.llm import deduct_llm_quota, ensure_llm_quota_available -from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance - if TYPE_CHECKING: from graphon.nodes.llm.node import LLMNode from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index ada065a943..87f005a250 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -14,6 +14,13 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Union +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from graphon.entities import WorkflowExecution, WorkflowNodeExecution from graphon.enums import ( WorkflowExecutionStatus, @@ -38,14 +45,6 @@ from graphon.graph_events import ( NodeRunSucceededEvent, ) from graphon.node_events import NodeRunResult - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 3d8a7a54f3..9e3c187210 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,9 +6,6 @@ import re import threading from collections.abc import Iterable -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelType - from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -18,6 +15,8 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index a5297fa33a..dc831e5cac 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,9 +3,6 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts @@ -31,6 +28,9 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 9c22d5e67c..352e6bfd49 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ from typing import Any, Literal, TypedDict -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject, I18nObjectDict +from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index c012e128f4..6a3f9e684a 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,11 +2,10 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type -from graphon.file import File, FileTransferMethod, FileType - from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index d304c982cd..04ae193396 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias -from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field +from graphon.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index a440829b46..bfa4f56915 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,7 +6,6 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -16,6 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 84d95c38c6..e99a131500 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,10 +1,11 @@ from collections.abc import Sequence from enum import StrEnum, auto +from pydantic import BaseModel, ConfigDict + from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel from graphon.model_runtime.entities.provider_entities import ProviderEntity -from pydantic import BaseModel, ConfigDict class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d07f6f913a..6bbf163c9d 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -8,16 +8,6 @@ from collections.abc import Iterator, Sequence from json import JSONDecodeError from typing import Any -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -34,6 +24,16 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel): else [], ) - def validate_provider_credentials( - self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None - ): + def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""): """ Validate custom credentials. :param credentials: provider credentials :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate - :param session: optional database session :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) - credential_record = s.execute(stmt).scalar_one_or_none() - # fix origin data + credential_record = session.execute(stmt).scalar_one_or_none() if credential_record and credential_record.encrypted_config: if not credential_record.encrypted_config.startswith("{"): original_credentials = {"openai_api_key": credential_record.encrypted_config} @@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables: + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def _generate_provider_credential_name(self, session) -> str: """ @@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: - if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session): raise ValueError(f"Credential with name '{credential_name}' already exists.") else: - credential_name = self._generate_provider_credential_name(session) + credential_name = self._generate_provider_credential_name(pre_session) - credentials = self.validate_provider_credentials(credentials=credentials, session=session) + credentials = self.validate_provider_credentials(credentials=credentials) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) try: new_record = ProviderCredential( @@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel): session.flush() if not provider_record: - # If provider record does not exist, create it provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, @@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_provider_credential_name_exists( - credential_name=credential_name, session=session, exclude_id=credential_id + credential_name=credential_name, session=pre_session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") - credentials = self.validate_provider_credentials( - credentials=credentials, credential_id=credential_id, session=session - ) + credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, @@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel): ProviderCredential.provider_name.in_(self._get_provider_names()), ) - # Get the credential record to update credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: @@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel): model: str, credentials: dict[str, Any], credential_id: str = "", - session: Session | None = None, ): """ Validate custom model credentials. @@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel): :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, @@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type, ) - credential_record = s.execute(stmt).scalar_one_or_none() + credential_record = session.execute(stmt).scalar_one_or_none() original_credentials = ( json.loads(credential_record.encrypted_config) if credential_record and credential_record.encrypted_config @@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables: + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def create_custom_model_credential( self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None @@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel): :param credentials: model credentials dict :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: if self._check_custom_model_credential_name_exists( - model=model, model_type=model_type, credential_name=credential_name, session=session + model=model, model_type=model_type, credential_name=credential_name, session=pre_session ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") else: credential_name = self._generate_custom_model_credential_name( - model=model, model_type=model_type, session=session + model=model, model_type=model_type, session=pre_session ) - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, model=model, credentials=credentials, session=session - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) try: @@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel): session.add(credential) session.flush() - # save provider model if not provider_model_record: provider_model_record = ProviderModel( tenant_id=self.tenant_id, @@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel): :param credential_id: credential id :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, - session=session, + session=pre_session, exclude_id=credential_id, ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, - model=model, - credentials=credentials, - credential_id=credential_id, - session=session, - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) stmt = select(ProviderModelCredential).where( @@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel): raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 95431c0e01..72b29c2277 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -3,7 +3,6 @@ from __future__ import annotations from enum import StrEnum, auto from typing import Any, Union -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -13,6 +12,7 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 35bfcfb6a5..951e065b2c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,7 +4,6 @@ from threading import Lock from typing import Any import httpx -from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -14,6 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client +from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index a1e782a094..dc37a36943 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,14 +2,13 @@ import logging import secrets from typing import cast -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index f8f56e12d2..8bcb899b23 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,12 +1,12 @@ from typing import Any from flask import Flask -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel +from graphon.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index a8ad7c9179..d2e375626f 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,6 +5,11 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -21,11 +26,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance class ResponseFormat(StrEnum): diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 72171d1536..884610ca82 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -3,12 +3,11 @@ import logging from collections.abc import Mapping from typing import Any, NotRequired, TypedDict, cast -from graphon.variables.input_entities import VariableEntity, VariableEntityType - from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 7e35044176..7b5a7635f1 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse -from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError +from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 5809d6f74a..d840ee213c 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,5 +1,14 @@ from collections.abc import Sequence +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from factories import file_factory from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -10,15 +19,6 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.file_access import DatabaseFileAccessController -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from extensions.ext_database import db -from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 36beb55d7f..d8d8dfedd8 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -2,6 +2,15 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload +from configs import dify_config +from core.entities import PluginCredentialType +from core.entities.embedding_type import EmbeddingInputType +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import ModelLoadBalancingConfiguration +from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.provider_manager import ProviderManager +from extensions.ext_redis import redis_client from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -15,16 +24,6 @@ from graphon.model_runtime.model_providers.__base.rerank_model import RerankMode from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from graphon.model_runtime.model_providers.__base.tts_model import TTSModel - -from configs import dify_config -from core.entities import PluginCredentialType -from core.entities.embedding_type import EmbeddingInputType -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import ModelLoadBalancingConfiguration -from core.errors.error import ProviderTokenNotInitError -from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager -from core.provider_manager import ProviderManager -from extensions.ext_redis import redis_client from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index fda00ac3b9..d78ce90aa1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -1,8 +1,8 @@ from enum import StrEnum -from pydantic import BaseModel, ValidationInfo, field_validator +from pydantic import BaseModel -from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import validate_project_name, validate_url class TracingProviderEnum(StrEnum): @@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel): return validate_project_name(v, default_name) -class ArizeConfig(BaseTracingConfig): - """ - Model class for Arize tracing config. - """ - - api_key: str | None = None - space_id: str | None = None - project: str | None = None - endpoint: str = "https://otlp.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://otlp.arize.com") - - -class PhoenixConfig(BaseTracingConfig): - """ - Model class for Phoenix tracing config. - """ - - api_key: str | None = None - project: str | None = None - endpoint: str = "https://app.phoenix.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://app.phoenix.arize.com") - - -class LangfuseConfig(BaseTracingConfig): - """ - Model class for Langfuse tracing config. - """ - - public_key: str - secret_key: str - host: str = "https://api.langfuse.com" - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://api.langfuse.com") - - -class LangSmithConfig(BaseTracingConfig): - """ - Model class for Langsmith tracing config. - """ - - api_key: str - project: str - endpoint: str = "https://api.smith.langchain.com" - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # LangSmith only allows HTTPS - return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) - - -class OpikConfig(BaseTracingConfig): - """ - Model class for Opik tracing config. - """ - - api_key: str | None = None - project: str | None = None - workspace: str | None = None - url: str = "https://www.comet.com/opik/api/" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "Default Project") - - @field_validator("url") - @classmethod - def url_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") - - -class WeaveConfig(BaseTracingConfig): - """ - Model class for Weave tracing config. - """ - - api_key: str - entity: str | None = None - project: str - endpoint: str = "https://trace.wandb.ai" - host: str | None = None - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # Weave only allows HTTPS for endpoint - return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - if v is not None and v.strip() != "": - return validate_url(v, v, allowed_schemes=("https", "http")) - return v - - -class AliyunConfig(BaseTracingConfig): - """ - Model class for Aliyun tracing config. - """ - - app_name: str = "dify_app" - license_key: str - endpoint: str - - @field_validator("app_name") - @classmethod - def app_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - @field_validator("license_key") - @classmethod - def license_key_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("License key cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # aliyun uses two URL formats, which may include a URL path - return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") - - -class TencentConfig(BaseTracingConfig): - """ - Tencent APM tracing config - """ - - token: str - endpoint: str - service_name: str - - @field_validator("token") - @classmethod - def token_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("Token cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") - - @field_validator("service_name") - @classmethod - def service_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - -class MLflowConfig(BaseTracingConfig): - """ - Model class for MLflow tracing config. - """ - - tracking_uri: str = "http://localhost:5000" - experiment_id: str = "0" # Default experiment id in MLflow is 0 - username: str | None = None - password: str | None = None - - @field_validator("tracking_uri") - @classmethod - def tracking_uri_validator(cls, v, info: ValidationInfo): - if isinstance(v, str) and v.startswith("databricks"): - raise ValueError( - "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." - ) - return validate_url_with_path(v, "http://localhost:5000") - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - -class DatabricksConfig(BaseTracingConfig): - """ - Model class for Databricks (Databricks-managed MLflow) tracing config. - """ - - experiment_id: str - host: str - client_id: str | None = None - client_secret: str | None = None - personal_access_token: str | None = None - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index cd63951537..e7ba6e502b 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict): class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): def __getitem__(self, provider: str) -> TracingProviderConfigEntry: - match provider: - case TracingProviderEnum.LANGFUSE: - from core.ops.entities.config_entity import LangfuseConfig - from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + try: + match provider: + case TracingProviderEnum.LANGFUSE: + from dify_trace_langfuse.config import LangfuseConfig + from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace - return { - "config_class": LangfuseConfig, - "secret_keys": ["public_key", "secret_key"], - "other_keys": ["host", "project_key"], - "trace_instance": LangFuseDataTrace, - } + return { + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, + } - case TracingProviderEnum.LANGSMITH: - from core.ops.entities.config_entity import LangSmithConfig - from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace + case TracingProviderEnum.LANGSMITH: + from dify_trace_langsmith.config import LangSmithConfig + from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace - return { - "config_class": LangSmithConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": LangSmithDataTrace, - } + return { + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + } - case TracingProviderEnum.OPIK: - from core.ops.entities.config_entity import OpikConfig - from core.ops.opik_trace.opik_trace import OpikDataTrace + case TracingProviderEnum.OPIK: + from dify_trace_opik.config import OpikConfig + from dify_trace_opik.opik_trace import OpikDataTrace - return { - "config_class": OpikConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "url", "workspace"], - "trace_instance": OpikDataTrace, - } + return { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + } - case TracingProviderEnum.WEAVE: - from core.ops.entities.config_entity import WeaveConfig - from core.ops.weave_trace.weave_trace import WeaveDataTrace + case TracingProviderEnum.WEAVE: + from dify_trace_weave.config import WeaveConfig + from dify_trace_weave.weave_trace import WeaveDataTrace - return { - "config_class": WeaveConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint", "host"], - "trace_instance": WeaveDataTrace, - } - case TracingProviderEnum.ARIZE: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import ArizeConfig + return { + "config_class": WeaveConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "entity", "endpoint", "host"], + "trace_instance": WeaveDataTrace, + } + case TracingProviderEnum.ARIZE: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import ArizeConfig - return { - "config_class": ArizeConfig, - "secret_keys": ["api_key", "space_id"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.PHOENIX: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import PhoenixConfig + return { + "config_class": ArizeConfig, + "secret_keys": ["api_key", "space_id"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.PHOENIX: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import PhoenixConfig - return { - "config_class": PhoenixConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.ALIYUN: - from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace - from core.ops.entities.config_entity import AliyunConfig + return { + "config_class": PhoenixConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.ALIYUN: + from dify_trace_aliyun.aliyun_trace import AliyunDataTrace + from dify_trace_aliyun.config import AliyunConfig - return { - "config_class": AliyunConfig, - "secret_keys": ["license_key"], - "other_keys": ["endpoint", "app_name"], - "trace_instance": AliyunDataTrace, - } - case TracingProviderEnum.MLFLOW: - from core.ops.entities.config_entity import MLflowConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": AliyunConfig, + "secret_keys": ["license_key"], + "other_keys": ["endpoint", "app_name"], + "trace_instance": AliyunDataTrace, + } + case TracingProviderEnum.MLFLOW: + from dify_trace_mlflow.config import MLflowConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": MLflowConfig, - "secret_keys": ["password"], - "other_keys": ["tracking_uri", "experiment_id", "username"], - "trace_instance": MLflowDataTrace, - } - case TracingProviderEnum.DATABRICKS: - from core.ops.entities.config_entity import DatabricksConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": MLflowConfig, + "secret_keys": ["password"], + "other_keys": ["tracking_uri", "experiment_id", "username"], + "trace_instance": MLflowDataTrace, + } + case TracingProviderEnum.DATABRICKS: + from dify_trace_mlflow.config import DatabricksConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": DatabricksConfig, - "secret_keys": ["personal_access_token", "client_secret"], - "other_keys": ["host", "client_id", "experiment_id"], - "trace_instance": MLflowDataTrace, - } + return { + "config_class": DatabricksConfig, + "secret_keys": ["personal_access_token", "client_secret"], + "other_keys": ["host", "client_id", "experiment_id"], + "trace_instance": MLflowDataTrace, + } - case TracingProviderEnum.TENCENT: - from core.ops.entities.config_entity import TencentConfig - from core.ops.tencent_trace.tencent_trace import TencentDataTrace + case TracingProviderEnum.TENCENT: + from dify_trace_tencent.config import TencentConfig + from dify_trace_tencent.tencent_trace import TencentDataTrace - return { - "config_class": TencentConfig, - "secret_keys": ["token"], - "other_keys": ["endpoint", "service_name"], - "trace_instance": TencentDataTrace, - } + return { + "config_class": TencentConfig, + "secret_keys": ["token"], + "other_keys": ["endpoint", "service_name"], + "trace_instance": TencentDataTrace, + } - case _: - raise KeyError(f"Unsupported tracing provider: {provider}") + case _: + raise KeyError(f"Unsupported tracing provider: {provider}") + except ImportError: + raise ImportError(f"Provider {provider} is not installed.") provider_config_map = OpsTraceProviderConfigMap() diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index a4b24ff849..c92438960a 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -3,20 +3,6 @@ from binascii import hexlify, unhexlify from collections.abc import Generator from typing import Any -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType - from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager @@ -33,6 +19,19 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 9478997494..9550e49992 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,3 +1,4 @@ +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from graphon.enums import BuiltinNodeTypes from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( @@ -8,8 +9,6 @@ from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) - -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 4d28032a57..89e0e8881c 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,7 +3,6 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any -from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -14,6 +13,7 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index e0ddb746c7..257638ad77 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,8 +6,6 @@ from datetime import datetime from enum import StrEnum from typing import Any -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin @@ -18,6 +16,8 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginDaemonBasicResponse[T: BaseModel | dict | list | bool | str](BaseModel): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 4a85952dcd..1474883204 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -4,6 +4,10 @@ from collections.abc import Mapping from typing import Any, Literal from flask import Response +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from core.entities.provider_entities import BasicProviderConfig +from core.plugin.utils.http_parser import deserialize_response from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -21,10 +25,6 @@ from graphon.nodes.parameter_extractor.entities import ( from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from core.entities.provider_entities import BasicProviderConfig -from core.plugin.utils.http_parser import deserialize_response class InvokeCredentials(BaseModel): diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7f36560b49..9ee8469892 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -5,14 +5,6 @@ from collections.abc import Callable, Generator from typing import Any, cast import httpx -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from yarl import URL @@ -37,6 +29,14 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 703af63f7c..47608bdfa6 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -2,13 +2,6 @@ import binascii from collections.abc import Generator, Sequence from typing import IO, Any -from graphon.model_runtime.entities.llm_entities import LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.utils.encoders import jsonable_encoder - from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, @@ -20,6 +13,12 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 90350f8400..12d8e282b2 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,8 +1,7 @@ from typing import Any -from graphon.file import File - from core.tools.entities.tool_entities import ToolSelector +from graphon.file import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 19b5e9223a..24e05ef865 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,13 @@ from collections.abc import Mapping, Sequence from typing import cast +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -13,14 +20,6 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.runtime import VariablePool -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser - class AdvancedPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 9be70199b7..8f1d51f08a 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,17 +1,16 @@ from typing import cast -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.prompt_transform import PromptTransform +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 4539ae9f11..6ff2f44cdc 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,12 +1,11 @@ from typing import Any -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index dc8391a6a5..1665bdeb52 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -4,6 +4,12 @@ from collections.abc import Mapping, Sequence from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, TypedDict, cast +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, @@ -13,13 +19,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index dbda749925..ba76eb0c4e 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from typing import Any, cast +from core.prompt.simple_prompt_transform import ModelMode from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -11,8 +12,6 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, ) -from core.prompt.simple_prompt_transform import ModelMode - class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 39ef31632e..c3bbe8fc09 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -6,14 +6,6 @@ from collections.abc import Sequence from json import JSONDecodeError from typing import TYPE_CHECKING, Any -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -41,6 +33,14 @@ from core.helper.position_helper import is_filtered from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index f978e072f3..7e71d67ec0 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired, TypedDict from flask import Flask, current_app -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only @@ -24,6 +23,7 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, @@ -195,6 +195,23 @@ class RetrievalService: ) return all_documents + @classmethod + def _filter_documents_by_vector_score_threshold( + cls, documents: list[Document], score_threshold: float | None + ) -> list[Document]: + """Keep documents whose stored retrieval score meets the threshold. + + Used when hybrid search skips early vector thresholding but no rerank + runner applies a threshold afterward (same rule as ``calculate_vector_score``). + """ + if score_threshold is None: + return documents + return [ + document + for document in documents + if document.metadata and document.metadata.get("score", 0) >= score_threshold + ] + @classmethod def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]: """Deduplicate documents in O(n) while preserving first-seen order. @@ -294,13 +311,20 @@ class RetrievalService: vector = Vector(dataset=dataset) documents = [] + # Hybrid search merges keyword / full-text / vector hits and then reranks + # (weighted fusion or reranking model). Applying the user score threshold at + # vector retrieval time uses embedding similarity, which is not comparable to + # reranked or fused scores and incorrectly drops high-quality chunks (#35233). + embedding_score_threshold = ( + 0.0 if retrieval_method == RetrievalMethod.HYBRID_SEARCH else score_threshold + ) if query_type == QueryType.TEXT_QUERY: documents.extend( vector.search_by_vector( query, search_type="similarity_score_threshold", top_k=top_k, - score_threshold=score_threshold, + score_threshold=embedding_score_threshold, filter={"group_id": [dataset.id]}, document_ids_filter=document_ids_filter, ) @@ -312,7 +336,7 @@ class RetrievalService: vector.search_by_file( file_id=query, top_k=top_k, - score_threshold=score_threshold, + score_threshold=embedding_score_threshold, filter={"group_id": [dataset.id]}, document_ids_filter=document_ids_filter, ) @@ -844,6 +868,10 @@ class RetrievalService: top_n=top_k, query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, ) + if not data_post_processor.rerank_runner and score_threshold: + all_documents_item = self._filter_documents_by_vector_score_threshold( + all_documents_item, score_threshold + ) all_documents.extend(all_documents_item) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index dddd5fc994..59d7f3c3c4 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,7 +4,6 @@ import time from abc import ABC, abstractmethod from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from configs import dify_config @@ -19,6 +18,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 8e9ebdd17a..f4699f6869 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,13 +3,13 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 9f1c73ec88..4926f44f16 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,8 +4,6 @@ import pickle from typing import Any, cast import numpy as np -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -15,6 +13,8 @@ from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index a487c49053..f8242efe31 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -7,16 +7,6 @@ from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) -from graphon.file import File, FileTransferMethod, FileType, file_manager -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -43,6 +33,16 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 087736d0b0..4ebf095904 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any -from graphon.file import File from pydantic import BaseModel, Field +from graphon.file import File + class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index a8d37845a5..bce08f998f 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,8 +1,5 @@ import base64 -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult - from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType @@ -10,6 +7,8 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from models.model import UploadFile diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 49123e13d0..d0732b269a 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -2,7 +2,6 @@ import math from collections import Counter import numpy as np -from graphon.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelManager from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -12,6 +11,7 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner +from graphon.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8ebc840b99..1453fe020b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,11 +9,6 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app -from graphon.file import File, FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select, update from sqlalchemy.orm import sessionmaker @@ -69,6 +64,11 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( ) from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index dce7b6226c..e617a9660e 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,10 +1,9 @@ from typing import Union -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 3383c7f3bd..2581c354dd 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -7,10 +7,9 @@ import re from collections.abc import Collection from typing import Any, Literal -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer - from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index b07c63fdf0..e87d1cd6b2 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -7,11 +7,11 @@ providing improved performance by offloading database operations to background w import logging -from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index cdb3af01a8..2451563317 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -8,7 +8,6 @@ providing improved performance by offloading database operations to background w import logging from collections.abc import Sequence -from graphon.entities import WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,6 +15,7 @@ from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) +from graphon.entities import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index ce3ad15759..4e83e70799 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -9,11 +9,11 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Protocol -from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index d74cc8f231..6be3902317 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -5,13 +5,13 @@ SQLAlchemy implementation of the WorkflowExecutionRepository. import json import logging -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 13e885672a..b036687bc9 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,10 +10,6 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any import psycopg2.errors -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError @@ -23,6 +19,10 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att from configs import dify_config from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 from models import ( diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index e539074303..95660ab93b 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -2,15 +2,14 @@ import io from collections.abc import Generator from typing import Any -from graphon.file import FileType -from graphon.file.file_manager import download -from graphon.model_runtime.entities.model_entities import ModelType - from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from graphon.file import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index f49c669fe0..ac3820f1ab 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -2,13 +2,12 @@ import io from collections.abc import Generator from typing import Any -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType - from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 14af63a962..d41503e1e6 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,12 +1,11 @@ from __future__ import annotations -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage - from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 0a2c37c563..168e5f4493 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -6,7 +6,6 @@ from typing import Any, Union from urllib.parse import urlencode import httpx -from graphon.file.file_manager import download from core.helper import ssrf_proxy from core.tools.__base.tool import Tool @@ -14,6 +13,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError +from graphon.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 410ec72baf..42a88c0003 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,7 +2,6 @@ from collections.abc import Mapping from datetime import datetime from typing import Any, Literal -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -10,6 +9,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from graphon.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index f6d09472b3..00fc8a8282 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -6,8 +6,6 @@ import logging from collections.abc import Generator, Mapping from typing import Any, cast -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata - from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import ( @@ -23,6 +21,7 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index d060fa8b49..3caacb8706 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -7,7 +7,6 @@ from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Union, cast -from graphon.file import FileTransferMethod, FileType from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom @@ -33,6 +32,7 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.file import FileTransferMethod, FileType from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index d8674b3af9..b3424cd9a5 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -9,7 +9,6 @@ from mimetypes import guess_extension, guess_type from uuid import uuid4 import httpx -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from sqlalchemy import select from configs import dify_config @@ -17,6 +16,7 @@ from core.db.session_factory import session_factory from core.helper import ssrf_proxy from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index be13d40f3e..f4588904d3 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -8,7 +8,6 @@ from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import sqlalchemy as sa -from graphon.runtime import VariablePool from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import Session @@ -29,14 +28,13 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from extensions.ext_database import db +from graphon.runtime import VariablePool from models.provider_ids import ToolProviderID from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: pass -from graphon.model_runtime.utils.encoders import jsonable_encoder - from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source @@ -62,6 +60,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 03e3c5918d..b6890b2611 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,7 +1,6 @@ import threading from flask import Flask, current_app -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select @@ -15,6 +14,7 @@ from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 81c85bc90d..79d0c114d4 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -9,11 +9,11 @@ from uuid import UUID import numpy as np import pytz -from graphon.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8d6f83dc07..9e1d41cb39 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,6 +8,9 @@ import json from decimal import Decimal from typing import cast +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType +from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -20,10 +23,6 @@ from graphon.model_runtime.errors.invoke import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder - -from core.model_manager import ModelManager -from core.tools.entities.tool_entities import ToolProviderType -from extensions.ext_database import db from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 2159eb8638..45718cadb6 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,13 +1,12 @@ from collections.abc import Mapping, Sequence from typing import Any +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.entities import OutputVariableEntity from graphon.variables.input_entities import VariableEntity -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError - class WorkflowToolConfigurationUtils: @classmethod diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index a01004448a..5905fd919e 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Mapping -from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field from sqlalchemy import select from sqlalchemy.orm import Session @@ -25,6 +24,7 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 7c4f8ee03a..52ab605963 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,8 +5,6 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, cast -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -22,6 +20,8 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolInvokeError from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from models import Account, Tenant from models.model import App, EndUser from models.utils.file_input_compat import build_file_from_stored_mapping diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 61d1cd8540..24c1271488 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -8,7 +8,6 @@ from collections.abc import Mapping from datetime import datetime from typing import Any -from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse @@ -28,6 +27,7 @@ from core.trigger.debug.events import ( from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at from models.model import App diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index c52aad150b..51452c29a3 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index d9247b2593..e4f6b3b470 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,6 +1,12 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, @@ -12,13 +18,6 @@ from graphon.node_events import NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities import DatasourceProviderType -from core.plugin.impl.exc import PluginDaemonClientSideError -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.system_variables import SystemVariableKey, get_system_segment - from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index cad32f8d5b..28966f2392 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,9 +1,10 @@ from typing import Any, Literal, Union +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo + from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 04a10f9257..260881e49c 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,13 +1,13 @@ from typing import Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel from core.rag.entities import RerankingModelConfig, WeightedScoreConfig from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class RetrievalSetting(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index bb72fe3881..d5cab05dbe 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,17 +2,16 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template - from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 460ec693ce..3825f526a2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,11 +1,11 @@ from typing import Literal -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.llm.entities import ModelConfig, VisionConfig from pydantic import BaseModel, Field from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm.entities import ModelConfig, VisionConfig __all__ = ["Condition"] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 13624b27b3..47ad14b499 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,6 +8,11 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( @@ -27,12 +32,6 @@ from graphon.variables import ( ) from graphon.variables.segments import ArrayObjectSegment -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file_reference import parse_file_reference - from .entities import ( Condition, KnowledgeRetrievalNodeData, diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index bf5be2379a..23ed2cd408 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index e50de11bb9..c848a86255 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,13 +1,12 @@ from collections.abc import Mapping from typing import Any +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID - from .entities import TriggerEventNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 04f1f7e6bb..683c8d420f 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index a9753ab387..b46cc76a6e 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,10 @@ from collections.abc import Mapping -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node - from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index a30f877e4b..b261039448 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType -from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, field_validator from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType +from graphon.variables.types import SegmentType _WEBHOOK_HEADER_ALLOWED_TYPES: frozenset[SegmentType] = frozenset((SegmentType.STRING,)) diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index d942a718cc..13c4f05bfd 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,6 +2,10 @@ import logging from collections.abc import Mapping from typing import Any +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment_with_type from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.file import FileTransferMethod from graphon.node_events import NodeRunResult @@ -10,11 +14,6 @@ from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.variables.types import SegmentType from graphon.variables.variables import FileVariable -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment_with_type - from .entities import ContentType, WebhookData logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index b7e7a6e60f..0c535a1c5b 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -6,9 +6,9 @@ import click from sqlalchemy import select from werkzeug.exceptions import NotFound +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.document_index_event import document_index_created -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document from models.enums import IndexingStatus @@ -22,24 +22,25 @@ def handle(sender, **kwargs): document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) + with session_factory.create_session() as session: + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) - document = db.session.scalar( - select(Document).where( - Document.id == document_id, - Document.dataset_id == dataset_id, + document = session.scalar( + select(Document).where( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) ) - ) - if not document: - raise NotFound("Document not found") + if not document: + raise NotFound("Document not found") - document.indexing_status = IndexingStatus.PARSING - document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() + document.indexing_status = IndexingStatus.PARSING + document.processing_started_at = naive_utc_now() + documents.append(document) + session.add(document) + session.commit() with contextlib.suppress(Exception): try: diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 84be592b1a..5e2a456dce 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.enums import CustomizeTokenStrategy from models.model import Site @@ -22,6 +22,6 @@ def handle(sender, **kwargs): created_by=app.created_by, updated_by=app.updated_by, ) - - db.session.add(site) - db.session.commit() + with session_factory.create_session() as session: + session.add(site) + session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 7bd8e88231..ba9758175f 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,12 +1,11 @@ import logging -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.tool.entities import ToolEntity - from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_draft_workflow_was_synced +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 86b5b2bbf0..6769b94cde 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,11 +1,11 @@ from typing import cast -from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db +from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 5cc58f27c4..69d1f1ab07 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,11 +5,12 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk - from graphon.model_runtime.errors.invoke import InvokeRateLimitError from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException + from graphon.model_runtime.errors.invoke import InvokeRateLimitError + try: from langfuse._utils import parse_error diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index db599c5d49..64ff0f0674 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value +from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 2745141431..7f77a0437a 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,12 +20,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string +from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index d0f3e2e244..544109276d 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -3,14 +3,14 @@ import logging import os import time -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 23d324f9ea..fbf379b3e5 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -10,17 +10,17 @@ Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when import json from typing import Any, Protocol -from graphon.enums import BuiltinNodeTypes -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment def should_include_content() -> bool: diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 335c5cc29e..ec3c78a12d 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,12 +6,12 @@ import logging from collections.abc import Mapping from typing import Any -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 6df5f62c15..56672d1fd4 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,13 +6,13 @@ import logging from collections.abc import Sequence from typing import Any -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index b9fdd9e1ca..75ddbba448 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,14 +2,14 @@ Parser for tool nodes that captures tool-specific metadata. """ -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.nodes.tool.entities import ToolNodeData from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.nodes.tool.entities import ToolNodeData class ToolNodeOTelParser: diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index 288d37d265..ce1fa441c2 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -10,8 +10,8 @@ from typing import Any from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.workflow.file_reference import build_file_reference -from extensions.ext_database import db from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from models import ToolFile, UploadFile @@ -135,29 +135,30 @@ def _build_from_local_file( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if row is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + row = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type", "custom"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - reference=build_file_reference(record_id=str(row.id)), - size=row.size, - storage_key=row.key, - ) + return File( + id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) def _build_from_remote_url( @@ -179,32 +180,33 @@ def _build_from_remote_url( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if upload_file is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type( - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - ) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - reference=build_file_reference(record_id=str(upload_file.id)), - size=upload_file.size, - storage_key=upload_file.key, - ) + return File( + id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) url = mapping.get("url") or mapping.get("remote_url") if not url: @@ -247,30 +249,31 @@ def _build_from_tool_file( ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) - tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") + with session_factory.create_session() as session: + tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - reference=build_file_reference(record_id=str(tool_file.id)), - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) + return File( + id=mapping.get("id"), + filename=tool_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) def _build_from_datasource_file( @@ -289,31 +292,32 @@ def _build_from_datasource_file( UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) - datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + with session_factory.create_session() as session: + datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("datasource_file_id"), - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - reference=build_file_reference(record_id=str(datasource_file.id)), - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) + return File( + id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 57205b5739..fd7acb14d3 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -8,6 +8,11 @@ shared conversion functions for legacy callers and tests. from collections.abc import Mapping, Sequence from typing import Any, cast +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) from graphon.variables.exc import VariableError from graphon.variables.factory import ( TypeMismatchError, @@ -31,12 +36,6 @@ from graphon.variables.variables import ( VariableBase, ) -from configs import dify_config -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) - __all__ = [ "TypeMismatchError", "UnsupportedSegmentTypeError", diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index cfe0015918..67b320beaa 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,10 +3,10 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields -from graphon.file import helpers as file_helpers from pydantic import computed_field, field_validator from fields.base import ResponseModel +from graphon.file import helpers as file_helpers simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 1a871204a0..ca18f1c203 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,12 +3,12 @@ from __future__ import annotations from datetime import datetime from uuid import uuid4 -from graphon.file import File from pydantic import Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from fields.base import ResponseModel from fields.conversation_fields import AgentThought, JSONValue, MessageFile +from graphon.file import File type JSONValueType = JSONValue diff --git a/api/fields/raws.py b/api/fields/raws.py index 4c65cdab7a..ee6f53b360 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,5 @@ from flask_restx import fields + from graphon.file import File diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index b0b6cc0b48..f9b5e98936 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields -from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter from fields.member_fields import simple_account_fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type diff --git a/api/libs/helper.py b/api/libs/helper.py index 69bd483515..ac69a11084 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,8 +16,6 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields -from graphon.file import helpers as file_helpers -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, TypeAdapter from pydantic.functional_validators import AfterValidator from typing_extensions import TypedDict @@ -25,6 +23,8 @@ from typing_extensions import TypedDict from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder if TYPE_CHECKING: from models import Account diff --git a/api/models/dataset.py b/api/models/dataset.py index 50301dd2d7..eee5c39a0e 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase): ) -class DocumentSegmentSummary(Base): +class DocumentSegmentSummary(TypeBase): __tablename__ = "document_segment_summaries" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), @@ -1725,25 +1725,40 @@ class DocumentSegmentSummary(Base): sa.Index("document_segment_summaries_status_idx", "status"), ) - id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # corresponds to DocumentSegment.id or parent chunk id chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - summary_content: Mapped[str] = mapped_column(LongText, nullable=True) - summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) - summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - status: Mapped[str] = mapped_column( - EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") + summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + status: Mapped[SummaryStatus] = mapped_column( + EnumText(SummaryStatus, length=32), + nullable=False, + server_default=sa.text("'generating'"), + default=SummaryStatus.GENERATING, + ) + error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - error: Mapped[str] = mapped_column(LongText, nullable=True) - enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - disabled_by = mapped_column(StringUUID, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) def __repr__(self): diff --git a/api/models/evaluation.py b/api/models/evaluation.py index 680d6ab31c..fce50c5f48 100644 --- a/api/models/evaluation.py +++ b/api/models/evaluation.py @@ -85,7 +85,7 @@ class EvaluationConfiguration(Base): """Return judgment config (stored in the judgement_conditions column).""" if self.judgement_conditions: parsed = json.loads(self.judgement_conditions) - return parsed if parsed else None + return parsed or None return None @property diff --git a/api/models/human_input.py b/api/models/human_input.py index 79c5d62f6a..b4c7a634b6 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,11 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship from core.workflow.human_input_compat import DeliveryMethodType +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 8eabf45363..7fe0731098 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,9 +14,6 @@ from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] -from graphon.enums import WorkflowExecutionStatus -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker @@ -24,6 +21,9 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from extensions.storage.storage_type import StorageType +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 from models.utils.file_input_compat import build_file_from_input_mapping diff --git a/api/models/provider.py b/api/models/provider.py index 8270961b31..2bb67d605b 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,10 +6,10 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column +from graphon.model_runtime.entities.model_entities import ModelType from libs.uuid_utils import uuidv7 from .base import TypeBase diff --git a/api/models/workflow.py b/api/models/workflow.py index 020adaa9a8..467d60f6ac 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -8,19 +8,6 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File -from graphon.file.constants import maybe_file_object -from graphon.variables import utils as variable_utils -from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -44,6 +31,19 @@ from core.workflow.variable_prefixes import ( ) from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -53,11 +53,10 @@ if TYPE_CHECKING: from .model import AppMode, UploadFile -from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase - from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter from factories import variable_factory +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import helper from .account import Account diff --git a/api/providers/README.md b/api/providers/README.md index a00ec8bc52..5d5e6db9af 100644 --- a/api/providers/README.md +++ b/api/providers/README.md @@ -10,3 +10,6 @@ This directory holds **optional workspace packages** that plug into Dify’s API Provider tests often live next to the package, e.g. `providers///tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`). +## Excluding Providers + +In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-` and `--group trace-` to enable specific providers. diff --git a/api/providers/trace/README.md b/api/providers/trace/README.md new file mode 100644 index 0000000000..a7ffa5ed26 --- /dev/null +++ b/api/providers/trace/README.md @@ -0,0 +1,78 @@ +# Trace providers + +This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others). + +Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping. + +## Architecture + +| Layer | Location | Role | +|--------|----------|------| +| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads | +| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class | +| Your package | `api/providers/trace/trace-/` | Pydantic config + subclass of `BaseTraceInstance` | + +At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype. + +## What you implement + +### 1. Config model (`BaseTracingConfig`) + +Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate. + +Fields fall into two groups used by the manager: + +- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords). +- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints). + +List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct. + +### 2. Trace instance (`BaseTraceInstance`) + +Subclass `BaseTraceInstance` and implement: + +```python +def trace(self, trace_info: BaseTraceInfo) -> None: + ... +``` + +Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including: + +- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace` +- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo` +- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo` + +You may ignore categories your backend does not support; existing providers often no-op unhandled types. + +Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context. + +### 3. Register in the API core + +Upstream changes are required so Dify knows your provider exists: + +1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`). +2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning: + - `config_class`: your Pydantic config type + - `secret_keys` / `other_keys`: lists of field names as above + - `trace_instance`: your `BaseTraceInstance` subclass + Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`. + +If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app. + +## Package layout + +Each provider is a normal uv workspace member, for example: + +- `api/providers/trace/trace-/pyproject.toml` — project name `dify-trace-`, dependencies on vendor SDKs +- `api/providers/trace/trace-/src/dify_trace_/` — `config.py`, `_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`. + +Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`. + +## Wiring into the `api` workspace + +In `api/pyproject.toml`: + +1. **`[tool.uv.sources]`** — `dify-trace- = { workspace = true }` +2. **`[dependency-groups]`** — add `trace- = ["dify-trace-"]` and include `dify-trace-` in `trace-all` if it should ship with the default bundle + +After changing metadata, run **`uv sync`** from `api/`. diff --git a/api/providers/trace/trace-aliyun/pyproject.toml b/api/providers/trace/trace-aliyun/pyproject.toml new file mode 100644 index 0000000000..bcef7e9fb1 --- /dev/null +++ b/api/providers/trace/trace-aliyun/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-aliyun" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Aliyun)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py similarity index 98% rename from api/core/ops/aliyun_trace/aliyun_trace.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py index 70aaf2a07b..54d2f8167f 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py @@ -1,12 +1,23 @@ import logging from collections.abc import Sequence -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.data_exporter.traceclient import ( TraceClient, build_endpoint, convert_datetime_to_nanoseconds, @@ -14,8 +25,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( convert_to_trace_id, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata +from dify_trace_aliyun.entities.semconv import ( DIFY_APP_ID, GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, @@ -34,7 +45,7 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -46,20 +57,9 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) -from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import AliyunConfig -from core.ops.entities.trace_entity import ( - BaseTraceInfo, - DatasetRetrievalTraceInfo, - GenerateNameTraceInfo, - MessageTraceInfo, - ModerationTraceInfo, - SuggestedQuestionTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py new file mode 100644 index 0000000000..e0133e6cc9 --- /dev/null +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py @@ -0,0 +1,32 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class AliyunConfig(BaseTracingConfig): + """ + Model class for Aliyun tracing config. + """ + + app_name: str = "dify_app" + license_key: str + endpoint: str + + @field_validator("app_name") + @classmethod + def app_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") + + @field_validator("license_key") + @classmethod + def license_key_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("License key cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # aliyun uses two URL formats, which may include a URL path + return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/data_exporter/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py similarity index 98% rename from api/core/ops/aliyun_trace/data_exporter/traceclient.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py index 67d5163b0f..00aab6bf89 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py @@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData -from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE INVALID_SPAN_ID: Final[int] = 0x0000000000000000 INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000 diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/semconv.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed similarity index 100% rename from api/core/ops/arize_phoenix_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed diff --git a/api/core/ops/aliyun_trace/utils.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py similarity index 97% rename from api/core/ops/aliyun_trace/utils.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py index aa35ac74c2..5678c66adb 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py @@ -2,11 +2,10 @@ import json from collections.abc import Mapping from typing import Any, TypedDict -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode -from core.ops.aliyun_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -15,8 +14,9 @@ from core.ops.aliyun_trace.entities.semconv import ( OUTPUT_VALUE, GenAISpanKind, ) -from core.rag.models.document import Document from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants @@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status: def create_links_from_trace_id(trace_id: str | None) -> list[Link]: - from core.ops.aliyun_trace.data_exporter.traceclient import create_link + from dify_trace_aliyun.data_exporter.traceclient import create_link links = [] if trace_id: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py similarity index 86% rename from api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index acb43d4036..286dda419c 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.trace import SpanKind, Status, StatusCode - -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from dify_trace_aliyun.data_exporter.traceclient import ( INVALID_SPAN_ID, SpanBuilder, TraceClient, @@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( create_link, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.trace import SpanKind, Status, StatusCode @pytest.fixture @@ -41,8 +40,8 @@ def trace_client_factory(): class TestTraceClient: - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname") def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): mock_gethostname.return_value = "test-host" client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -56,7 +55,7 @@ class TestTraceClient: client.shutdown() assert client.done is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -64,8 +63,8 @@ class TestTraceClient: client.export(spans) mock_exporter.export.assert_called_once_with(spans) - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 405 @@ -74,8 +73,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 500 @@ -84,8 +83,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is False - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): mock_head.side_effect = httpx.RequestError("Connection error") @@ -93,12 +92,12 @@ class TestTraceClient: with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"): client.api_check() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_get_project_url(self, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_add_span(self, mock_exporter_class, trace_client_factory): client = trace_client_factory( service_name="test-service", @@ -134,8 +133,8 @@ class TestTraceClient: assert len(client.queue) == 2 mock_notify.assert_called_once() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.logger") def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) @@ -159,7 +158,7 @@ class TestTraceClient: assert len(client.queue) == 1 mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export_batch_error(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value mock_exporter.export.side_effect = Exception("Export failed") @@ -168,11 +167,11 @@ class TestTraceClient: mock_span = MagicMock(spec=ReadableSpan) client.queue.append(mock_span) - with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger: + with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger: client._export_batch() mock_logger.warning.assert_called() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_worker_loop(self, mock_exporter_class, trace_client_factory): # We need to test the wait timeout in _worker # But _worker runs in a thread. Let's mock condition.wait. @@ -189,7 +188,7 @@ class TestTraceClient: # mock_wait might have been called assert mock_wait.called or client.done - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -268,7 +267,7 @@ def test_generate_span_id(): assert span_id != INVALID_SPAN_ID # Test retry loop - with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand: + with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand: mock_rand.side_effect = [INVALID_SPAN_ID, 999] span_id = generate_span_id() assert span_id == 999 @@ -290,7 +289,7 @@ def test_convert_to_trace_id(): def test_convert_string_to_id(): assert convert_string_to_id("test") > 0 # Test with None string - with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen: + with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen: mock_gen.return_value = 12345 assert convert_string_to_id(None) == 12345 diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py index 2fcb927e0c..38d33dd21b 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -1,11 +1,10 @@ import pytest +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event from opentelemetry.trace import SpanKind, Status, StatusCode from pydantic import ValidationError -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata - class TestTraceMetadata: def test_trace_metadata_init(self): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py index 3961555b9a..9cab40748f 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py @@ -1,4 +1,4 @@ -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( ACS_ARMS_SERVICE_FEATURE, GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py similarity index 99% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py index c2324fdec4..c1b11c9186 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py @@ -4,12 +4,11 @@ from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import MagicMock +import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import pytest -from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags - -import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module -from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.aliyun_trace import AliyunDataTrace +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, GEN_AI_OUTPUT_MESSAGE, @@ -24,7 +23,8 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.entities.config_entity import AliyunConfig +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py similarity index 95% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py index e4d8f2d5ea..a9e7b80c2a 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py @@ -1,9 +1,7 @@ import json from unittest.mock import MagicMock -from opentelemetry.trace import Link, StatusCode - -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -11,7 +9,7 @@ from core.ops.aliyun_trace.entities.semconv import ( INPUT_VALUE, OUTPUT_VALUE, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -23,6 +21,8 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) +from opentelemetry.trace import Link, StatusCode + from core.rag.models.document import Document from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionStatus @@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = end_user_data - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = None - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -112,9 +112,9 @@ def test_get_workflow_node_status(): def test_create_links_from_trace_id(monkeypatch): # Mock create_link mock_link = MagicMock(spec=Link) - import core.ops.aliyun_trace.data_exporter.traceclient + import dify_trace_aliyun.data_exporter.traceclient - monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) # Trace ID None assert create_links_from_trace_id(None) == [] diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..1b24ee7421 --- /dev/null +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py @@ -0,0 +1,85 @@ +import pytest +from dify_trace_aliyun.config import AliyunConfig +from pydantic import ValidationError + + +class TestAliyunConfig: + """Test cases for AliyunConfig""" + + def test_valid_config(self): + """Test valid Aliyun configuration""" + config = AliyunConfig( + app_name="test_app", + license_key="test_license_key", + endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", + ) + assert config.app_name == "test_app" + assert config.license_key == "test_license_key" + assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + assert config.app_name == "dify_app" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + AliyunConfig() + + with pytest.raises(ValidationError): + AliyunConfig(license_key="test_license") + + with pytest.raises(ValidationError): + AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_app_name_validation_empty(self): + """Test app_name validation with empty value""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" + ) + assert config.app_name == "dify_app" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = AliyunConfig(license_key="test_license", endpoint="") + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation preserves path for Aliyun endpoints""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + ) + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_license_key_required(self): + """Test that license_key is required and cannot be empty""" + with pytest.raises(ValidationError): + AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_valid_endpoint_format_examples(self): + """Test valid endpoint format examples from comments""" + valid_endpoints = [ + # cms2.0 public endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", + # cms2.0 intranet endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", + # xtrace public endpoint + "http://tracing-cn-heyuan.arms.aliyuncs.com", + # xtrace intranet endpoint + "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", + ] + + for endpoint in valid_endpoints: + config = AliyunConfig(license_key="test_license", endpoint=endpoint) + assert config.endpoint == endpoint diff --git a/api/providers/trace/trace-arize-phoenix/pyproject.toml b/api/providers/trace/trace-arize-phoenix/pyproject.toml new file mode 100644 index 0000000000..9e756944c9 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-arize-phoenix" +version = "0.0.1" +dependencies = [ + "arize-phoenix-otel~=0.15.0", +] +description = "Dify ops tracing provider (Arize / Phoenix)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langfuse_trace/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py similarity index 100% rename from api/core/ops/langfuse_trace/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py similarity index 99% rename from api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index dd5edde630..96df49ed0e 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -6,7 +6,6 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse -from graphon.enums import WorkflowNodeExecutionStatus from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -26,7 +25,6 @@ from opentelemetry.util.types import AttributeValue from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -40,7 +38,9 @@ from core.ops.entities.trace_entity import ( ) from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py new file mode 100644 index 0000000000..6eac5b30d2 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py @@ -0,0 +1,45 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class ArizeConfig(BaseTracingConfig): + """ + Model class for Arize tracing config. + """ + + api_key: str | None = None + space_id: str | None = None + project: str | None = None + endpoint: str = "https://otlp.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://otlp.arize.com") + + +class PhoenixConfig(BaseTracingConfig): + """ + Model class for Phoenix tracing config. + """ + + api_key: str | None = None + project: str | None = None + endpoint: str = "https://app.phoenix.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://app.phoenix.arize.com") diff --git a/api/core/ops/langfuse_trace/entities/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed similarity index 100% rename from api/core/ops/langfuse_trace/entities/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index 4ce9e22fd7..b0691a87ea 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -2,11 +2,7 @@ from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, patch import pytest -from opentelemetry.sdk.trace import Tracer -from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes -from opentelemetry.trace import StatusCode - -from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( +from dify_trace_arize_phoenix.arize_phoenix_trace import ( ArizePhoenixDataTrace, datetime_to_nanos, error_to_string, @@ -15,7 +11,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( setup_tracer, wrap_span_metadata, ) -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from opentelemetry.sdk.trace import Tracer +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.trace import StatusCode + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -80,7 +80,7 @@ def test_datetime_to_nanos(): expected = int(dt.timestamp() * 1_000_000_000) assert datetime_to_nanos(dt) == expected - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt: mock_now = MagicMock() mock_now.timestamp.return_value = 1704110400.0 mock_dt.now.return_value = mock_now @@ -142,8 +142,8 @@ def test_wrap_span_metadata(): assert res == {"a": 1, "b": 2, "created_from": "Dify"} -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_arize(mock_provider, mock_exporter): config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") setup_tracer(config) @@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter): assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_phoenix(mock_provider, mock_exporter): config = PhoenixConfig(endpoint="http://p.com", project="p") setup_tracer(config) @@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter): def test_setup_tracer_exception(): config = ArizeConfig(endpoint="http://a.com", project="p") - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): with pytest.raises(Exception, match="boom"): setup_tracer(config) @@ -172,7 +172,7 @@ def test_setup_tracer_exception(): @pytest.fixture def trace_instance(): - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup: mock_tracer = MagicMock(spec=Tracer) mock_processor = MagicMock() mock_setup.return_value = (mock_tracer, mock_processor) @@ -228,9 +228,9 @@ def test_trace_exception(trace_instance): trace_instance.trace(_make_workflow_info()) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac assert trace_instance.tracer.start_span.call_count >= 2 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_no_app_id(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance): trace_instance.workflow_trace(info) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_success(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() @@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance): assert trace_instance.tracer.start_span.call_count >= 1 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_with_error(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py index 4b925390d9..a01c63ae61 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py @@ -1,6 +1,6 @@ +from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind from openinference.semconv.trace import OpenInferenceSpanKindValues -from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..11e951c3b1 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py @@ -0,0 +1,88 @@ +import pytest +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from pydantic import ValidationError + + +class TestArizeConfig: + """Test cases for ArizeConfig""" + + def test_valid_config(self): + """Test valid Arize configuration""" + config = ArizeConfig( + api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" + ) + assert config.api_key == "test_key" + assert config.space_id == "test_space" + assert config.project == "test_project" + assert config.endpoint == "https://custom.arize.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = ArizeConfig() + assert config.api_key is None + assert config.space_id is None + assert config.project is None + assert config.endpoint == "https://otlp.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = ArizeConfig(project="") + assert config.project == "default" + + def test_project_validation_none(self): + """Test project validation with None value""" + config = ArizeConfig(project=None) + assert config.project == "default" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = ArizeConfig(endpoint="") + assert config.endpoint == "https://otlp.arize.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation normalizes URL by removing path""" + config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") + assert config.endpoint == "https://custom.arize.com" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="ftp://invalid.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="invalid.com") + + +class TestPhoenixConfig: + """Test cases for PhoenixConfig""" + + def test_valid_config(self): + """Test valid Phoenix configuration""" + config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.phoenix.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = PhoenixConfig() + assert config.api_key is None + assert config.project is None + assert config.endpoint == "https://app.phoenix.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = PhoenixConfig(project="") + assert config.project == "default" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation with path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + + def test_endpoint_validation_without_path(self): + """Test endpoint validation without path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") + assert config.endpoint == "https://app.phoenix.arize.com" diff --git a/api/providers/trace/trace-langfuse/pyproject.toml b/api/providers/trace/trace-langfuse/pyproject.toml new file mode 100644 index 0000000000..27d2273a69 --- /dev/null +++ b/api/providers/trace/trace-langfuse/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langfuse" +version = "0.0.1" +dependencies = [ + "langfuse>=4.2.0,<5.0.0", +] +description = "Dify ops tracing provider (Langfuse)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langsmith_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py new file mode 100644 index 0000000000..90d1a2846b --- /dev/null +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py @@ -0,0 +1,19 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class LangfuseConfig(BaseTracingConfig): + """ + Model class for Langfuse tracing config. + """ + + public_key: str + secret_key: str + host: str = "https://api.langfuse.com" + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://api.langfuse.com") diff --git a/api/core/ops/langsmith_trace/entities/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py similarity index 100% rename from api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py similarity index 99% rename from api/core/ops/langfuse_trace/langfuse_trace.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py index 7eacc2be46..68881378a7 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py @@ -16,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -28,7 +27,10 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( +from core.ops.utils import filter_none_values +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( GenerationUsage, LangfuseGeneration, LangfuseSpan, @@ -36,8 +38,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( LevelEnum, UnitEnum, ) -from core.ops.utils import filter_none_values -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/mlflow_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed similarity index 100% rename from api/core/ops/mlflow_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py index a0bcc92795..952f10c34f 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py @@ -5,8 +5,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -17,14 +25,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( - LangfuseGeneration, - LangfuseSpan, - LangfuseTrace, - LevelEnum, - UnitEnum, -) -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -43,7 +43,7 @@ def langfuse_config(): def trace_instance(langfuse_config, monkeypatch): # Mock Langfuse client to avoid network calls mock_client = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client) instance = LangFuseDataTrace(langfuse_config) return instance @@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch): def test_init(langfuse_config, monkeypatch): mock_langfuse = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangFuseDataTrace(langfuse_config) @@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): # Mock DB and Repositories mock_session = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="log-1", error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() @@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..103d888eef --- /dev/null +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py @@ -0,0 +1,42 @@ +import pytest +from dify_trace_langfuse.config import LangfuseConfig +from pydantic import ValidationError + + +class TestLangfuseConfig: + """Test cases for LangfuseConfig""" + + def test_valid_config(self): + """Test valid Langfuse configuration""" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == "https://custom.langfuse.com" + + def test_valid_config_with_path(self): + host = "https://custom.langfuse.com/api/v1" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == host + + def test_default_values(self): + """Test default values are set correctly""" + config = LangfuseConfig(public_key="public", secret_key="secret") + assert config.host == "https://api.langfuse.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangfuseConfig() + + with pytest.raises(ValidationError): + LangfuseConfig(public_key="public") + + with pytest.raises(ValidationError): + LangfuseConfig(secret_key="secret") + + def test_host_validation_empty(self): + """Test host validation with empty value""" + config = LangfuseConfig(public_key="public", secret_key="secret", host="") + assert config.host == "https://api.langfuse.com" diff --git a/api/tests/unit_tests/core/ops/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py similarity index 92% rename from api/tests/unit_tests/core/ops/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py index 017ac8c891..0340ffb669 100644 --- a/api/tests/unit_tests/core/ops/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py @@ -4,14 +4,15 @@ from datetime import datetime, timedelta from types import SimpleNamespace from unittest.mock import MagicMock, patch -from core.ops.entities.config_entity import LangfuseConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace + from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from graphon.enums import BuiltinNodeTypes def _create_trace_instance() -> LangFuseDataTrace: - with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True): + with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True): return LangFuseDataTrace( LangfuseConfig( public_key="public-key", @@ -116,9 +117,9 @@ class TestLangFuseDataTraceCompletionStartTime: patch.object(trace, "add_span"), patch.object(trace, "add_generation") as add_generation, patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()), - patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()), + patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()), patch( - "core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=repository, ), ): diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml new file mode 100644 index 0000000000..8131952b28 --- /dev/null +++ b/api/providers/trace/trace-langsmith/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langsmith" +version = "0.0.1" +dependencies = [ + "langsmith~=0.7.30", +] +description = "Dify ops tracing provider (LangSmith)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/opik_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py similarity index 100% rename from api/core/ops/opik_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py new file mode 100644 index 0000000000..498b8c5e7e --- /dev/null +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py @@ -0,0 +1,20 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class LangSmithConfig(BaseTracingConfig): + """ + Model class for Langsmith tracing config. + """ + + api_key: str + project: str + endpoint: str = "https://api.smith.langchain.com" + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # LangSmith only allows HTTPS + return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) diff --git a/api/core/ops/tencent_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py similarity index 99% rename from api/core/ops/langsmith_trace/langsmith_trace.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py index 490c64af84..145bd70dbc 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py @@ -4,13 +4,11 @@ import uuid from datetime import datetime, timedelta from typing import cast -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -22,14 +20,16 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( +from core.ops.utils import filter_none_values, generate_dotted_order +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( LangSmithRunModel, LangSmithRunType, LangSmithRunUpdateModel, ) -from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/weave_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed similarity index 100% rename from api/core/ops/weave_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py rename to api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index 34c64c54a1..45e5894e4a 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -3,8 +3,14 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -15,12 +21,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( - LangSmithRunModel, - LangSmithRunType, - LangSmithRunUpdateModel, -) -from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -38,7 +38,7 @@ def langsmith_config(): def trace_instance(langsmith_config, monkeypatch): # Mock LangSmith client mock_client = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client) instance = LangSmithDataTrace(langsmith_config) return instance @@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch): def test_init(langsmith_config, monkeypatch): mock_client_class = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangSmithDataTrace(langsmith_config) @@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch): # Mock dependencies mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() @@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): trace_info.error = "" mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() @@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..37efaf69cf --- /dev/null +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py @@ -0,0 +1,35 @@ +import pytest +from dify_trace_langsmith.config import LangSmithConfig +from pydantic import ValidationError + + +class TestLangSmithConfig: + """Test cases for LangSmithConfig""" + + def test_valid_config(self): + """Test valid LangSmith configuration""" + config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.smith.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = LangSmithConfig(api_key="key", project="project") + assert config.endpoint == "https://api.smith.langchain.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangSmithConfig() + + with pytest.raises(ValidationError): + LangSmithConfig(api_key="key") + + with pytest.raises(ValidationError): + LangSmithConfig(project="project") + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") diff --git a/api/providers/trace/trace-mlflow/pyproject.toml b/api/providers/trace/trace-mlflow/pyproject.toml new file mode 100644 index 0000000000..fad6002944 --- /dev/null +++ b/api/providers/trace/trace-mlflow/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-mlflow" +version = "0.0.1" +dependencies = [ + "mlflow-skinny>=3.11.1", +] +description = "Dify ops tracing provider (MLflow / Databricks)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py similarity index 100% rename from api/core/ops/weave_trace/entities/__init__.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py new file mode 100644 index 0000000000..84914165e3 --- /dev/null +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py @@ -0,0 +1,46 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_integer_id, validate_url_with_path + + +class MLflowConfig(BaseTracingConfig): + """ + Model class for MLflow tracing config. + """ + + tracking_uri: str = "http://localhost:5000" + experiment_id: str = "0" # Default experiment id in MLflow is 0 + username: str | None = None + password: str | None = None + + @field_validator("tracking_uri") + @classmethod + def tracking_uri_validator(cls, v, info: ValidationInfo): + if isinstance(v, str) and v.startswith("databricks"): + raise ValueError( + "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." + ) + return validate_url_with_path(v, "http://localhost:5000") + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) + + +class DatabricksConfig(BaseTracingConfig): + """ + Model class for Databricks (Databricks-managed MLflow) tracing config. + """ + + experiment_id: str + host: str + client_id: str | None = None + client_secret: str | None = None + personal_access_token: str | None = None + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py similarity index 99% rename from api/core/ops/mlflow_trace/mlflow_trace.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py index c070a937be..4e4c45a532 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py @@ -4,7 +4,6 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow -from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace @@ -12,7 +11,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -25,7 +23,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import JSON_DICT_ADAPTER +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py similarity index 98% rename from api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py rename to api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index afc5726ede..20211456e3 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" +"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module.""" from __future__ import annotations @@ -9,8 +9,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig +from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -20,7 +21,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -179,7 +179,7 @@ def _make_node(**overrides): @pytest.fixture def mock_mlflow(): - with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock: yield mock @@ -187,10 +187,10 @@ def mock_mlflow(): def mock_tracing(): """Patch all MLflow tracing functions used by the module.""" with ( - patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, - patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, - patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, - patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start, + patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update, + patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set, + patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach, ): yield { "start": mock_start, @@ -202,7 +202,7 @@ def mock_tracing(): @pytest.fixture def mock_db(): - with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + with patch("dify_trace_mlflow.mlflow_trace.db") as mock: yield mock diff --git a/api/providers/trace/trace-opik/pyproject.toml b/api/providers/trace/trace-opik/pyproject.toml new file mode 100644 index 0000000000..874997168e --- /dev/null +++ b/api/providers/trace/trace-opik/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-opik" +version = "0.0.1" +dependencies = [ + "opik~=1.11.2", +] +description = "Dify ops tracing provider (Opik)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/config.py b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py new file mode 100644 index 0000000000..c16ff1d903 --- /dev/null +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py @@ -0,0 +1,25 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class OpikConfig(BaseTracingConfig): + """ + Model class for Opik tracing config. + """ + + api_key: str | None = None + project: str | None = None + workspace: str | None = None + url: str = "https://www.comet.com/opik/api/" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "Default Project") + + @field_validator("url") + @classmethod + def url_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py similarity index 99% rename from api/core/ops/opik_trace/opik_trace.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py index e0c7b9bfe5..2d124ac989 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py @@ -5,13 +5,11 @@ import uuid from datetime import datetime, timedelta from typing import Any, cast -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -24,7 +22,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory +from dify_trace_opik.config import OpikConfig from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed b/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py index c02ac413f2..eefed3c78c 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py @@ -5,8 +5,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_opik.config import OpikConfig +from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -17,7 +18,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -37,7 +37,7 @@ def opik_config(): @pytest.fixture def trace_instance(opik_config, monkeypatch): mock_client = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client) instance = OpikDataTrace(opik_config) return instance @@ -67,7 +67,7 @@ def test_prepare_opik_uuid(): def test_init(opik_config, monkeypatch): mock_opik = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik) monkeypatch.setenv("FILES_URL", "http://test.url") instance = OpikDataTrace(opik_config) @@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) node_llm = MagicMock() node_llm.id = LLM_NODE_ID @@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e", error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() @@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..5a54b70bba --- /dev/null +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py @@ -0,0 +1,48 @@ +import pytest +from dify_trace_opik.config import OpikConfig +from pydantic import ValidationError + + +class TestOpikConfig: + """Test cases for OpikConfig""" + + def test_valid_config(self): + """Test valid Opik configuration""" + config = OpikConfig( + api_key="test_key", + project="test_project", + workspace="test_workspace", + url="https://custom.comet.com/opik/api/", + ) + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.workspace == "test_workspace" + assert config.url == "https://custom.comet.com/opik/api/" + + def test_default_values(self): + """Test default values are set correctly""" + config = OpikConfig() + assert config.api_key is None + assert config.project is None + assert config.workspace is None + assert config.url == "https://www.comet.com/opik/api/" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = OpikConfig(project="") + assert config.project == "Default Project" + + def test_url_validation_empty(self): + """Test URL validation with empty value""" + config = OpikConfig(url="") + assert config.url == "https://www.comet.com/opik/api/" + + def test_url_validation_missing_suffix(self): + """Test URL validation requires /api/ suffix""" + with pytest.raises(ValidationError, match="URL should end with /api/"): + OpikConfig(url="https://custom.comet.com/opik/") + + def test_url_validation_invalid_scheme(self): + """Test URL validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + OpikConfig(url="ftp://custom.comet.com/opik/api/") diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py index ad9d0846be..fba290f5b8 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py @@ -14,8 +14,9 @@ import uuid from datetime import datetime from unittest.mock import MagicMock, patch +from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid + from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo -from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid # A stable UUID4 used as the workflow_run_id throughout all tests. _WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6" @@ -56,8 +57,8 @@ def _make_workflow_trace_info( def _make_opik_trace_instance() -> OpikDataTrace: """Construct an OpikDataTrace with the Opik SDK client mocked out.""" - with patch("core.ops.opik_trace.opik_trace.Opik"): - from core.ops.entities.config_entity import OpikConfig + with patch("dify_trace_opik.opik_trace.Opik"): + from dify_trace_opik.config import OpikConfig config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/") instance = OpikDataTrace(config) @@ -133,10 +134,10 @@ class TestWorkflowTraceWithoutMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): @@ -265,10 +266,10 @@ class TestWorkflowTraceWithMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): diff --git a/api/providers/trace/trace-tencent/pyproject.toml b/api/providers/trace/trace-tencent/pyproject.toml new file mode 100644 index 0000000000..eab06fc708 --- /dev/null +++ b/api/providers/trace/trace-tencent/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-tencent" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Tencent APM)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/tencent_trace/client.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py similarity index 100% rename from api/core/ops/tencent_trace/client.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py new file mode 100644 index 0000000000..398e6c55a8 --- /dev/null +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py @@ -0,0 +1,30 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig + + +class TencentConfig(BaseTracingConfig): + """ + Tencent APM tracing config + """ + + token: str + endpoint: str + service_name: str + + @field_validator("token") + @classmethod + def token_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("Token cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") + + @field_validator("service_name") + @classmethod + def service_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") diff --git a/api/core/ops/tencent_trace/entities/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/entities/__init__.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py diff --git a/api/core/ops/tencent_trace/entities/semconv.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py similarity index 100% rename from api/core/ops/tencent_trace/entities/semconv.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py diff --git a/api/core/ops/tencent_trace/entities/tencent_trace_entity.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py similarity index 100% rename from api/core/ops/tencent_trace/entities/tencent_trace_entity.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed b/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py similarity index 98% rename from api/core/ops/tencent_trace/span_builder.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py index f79095d966..763a85ffd7 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py @@ -6,8 +6,6 @@ import json import logging from datetime import datetime -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import Status, StatusCode from core.ops.entities.trace_entity import ( @@ -16,7 +14,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_tencent.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, GEN_AI_IS_ENTRY, @@ -40,9 +39,10 @@ from core.ops.tencent_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData -from core.ops.tencent_trace.utils import TencentTraceUtils -from core.rag.models.document import Document +from dify_trace_tencent.entities.tencent_trace_entity import SpanData +from dify_trace_tencent.utils import TencentTraceUtils +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py similarity index 98% rename from api/core/ops/tencent_trace/tencent_trace.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py index 84f54d8a5a..cfcf6b307e 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py @@ -4,15 +4,10 @@ Tencent APM tracing implementation with separated concerns import logging -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from graphon.nodes import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -23,12 +18,17 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.client import TencentTraceClient -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData -from core.ops.tencent_trace.span_builder import TencentSpanBuilder -from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from dify_trace_tencent.client import TencentTraceClient +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.entities.tencent_trace_entity import SpanData +from dify_trace_tencent.span_builder import TencentSpanBuilder +from dify_trace_tencent.utils import TencentTraceUtils from extensions.ext_database import db +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.nodes import BuiltinNodeTypes from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/utils.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py similarity index 100% rename from api/core/ops/tencent_trace/utils.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py similarity index 98% rename from api/tests/unit_tests/core/ops/tencent_trace/test_client.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py index 870c18e53e..1e656e2462 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py @@ -8,13 +8,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_tencent import client as client_module +from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version +from dify_trace_tencent.entities.tencent_trace_entity import SpanData from opentelemetry.sdk.trace import Event from opentelemetry.trace import Status, StatusCode -from core.ops.tencent_trace import client as client_module -from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData - metric_reader_instances: list[DummyMetricReader] = [] meter_provider_instances: list[DummyMeterProvider] = [] diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py similarity index 89% rename from api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py index 6113e5c6c8..e850a801f3 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py @@ -1,15 +1,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch -from opentelemetry.trace import StatusCode - -from core.ops.entities.trace_entity import ( - DatasetRetrievalTraceInfo, - MessageTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.ops.tencent_trace.entities.semconv import ( +from dify_trace_tencent.entities.semconv import ( GEN_AI_IS_ENTRY, GEN_AI_IS_STREAMING_REQUEST, GEN_AI_MODEL_NAME, @@ -23,7 +15,15 @@ from core.ops.tencent_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.tencent_trace.span_builder import TencentSpanBuilder +from dify_trace_tencent.span_builder import TencentSpanBuilder +from opentelemetry.trace import StatusCode + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + MessageTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) from core.rag.models.document import Document from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -31,7 +31,7 @@ from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutio class TestTencentSpanBuilder: def test_get_time_nanoseconds(self): - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: mock_convert.return_value = 123456789 dt = datetime.now() result = TencentSpanBuilder._get_time_nanoseconds(dt) @@ -48,7 +48,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {"answer": "world"} trace_info.metadata = {"conversation_id": "conv_id"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -70,7 +70,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {} trace_info.metadata = {} # No conversation_id - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 1 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -98,7 +98,7 @@ class TestTencentSpanBuilder: } node_execution.outputs = {"text": "world"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -123,7 +123,7 @@ class TestTencentSpanBuilder: "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, } - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -142,7 +142,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {"conversation_id": "conv_id"} trace_info.is_streaming_request = True - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -162,7 +162,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {} trace_info.is_streaming_request = False - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -182,7 +182,7 @@ class TestTencentSpanBuilder: trace_info.tool_inputs = {"i": 2} trace_info.tool_outputs = "result" - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 101 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1) @@ -204,7 +204,7 @@ class TestTencentSpanBuilder: ) trace_info.documents = [doc] - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 202 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) @@ -222,7 +222,7 @@ class TestTencentSpanBuilder: trace_info.end_time = datetime.now() trace_info.documents = [] - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 202 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) @@ -264,7 +264,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 303 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) @@ -286,7 +286,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 303 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) @@ -307,7 +307,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 404 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) @@ -329,7 +329,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 404 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) @@ -350,7 +350,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 505 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py similarity index 89% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py index 7afd0b824a..a91a0aa558 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py @@ -2,8 +2,9 @@ import logging from unittest.mock import MagicMock, patch import pytest +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.tencent_trace import TencentDataTrace -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -13,7 +14,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.tencent_trace import TencentDataTrace from graphon.entities import WorkflowNodeExecution from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin @@ -28,19 +28,19 @@ def tencent_config(): @pytest.fixture def mock_trace_client(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock: yield mock @pytest.fixture def mock_span_builder(): - with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock: yield mock @pytest.fixture def mock_trace_utils(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock: yield mock @@ -198,9 +198,9 @@ class TestTencentDataTrace: trace_info.workflow_run_id = "run-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.workflow_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace") @@ -230,9 +230,9 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.message_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace") @@ -262,9 +262,9 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.tool_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace") @@ -294,22 +294,22 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.dataset_retrieval_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace") def test_suggested_question_trace(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") def test_suggested_question_trace_exception(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") @@ -342,7 +342,7 @@ class TestTencentDataTrace: with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) # The exception should be caught by the outer handler since convert_to_span_id is called first mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -351,7 +351,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -381,7 +381,7 @@ class TestTencentDataTrace: node.id = "n1" mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456) assert result is None mock_log.assert_called_once() @@ -403,15 +403,13 @@ class TestTencentDataTrace: mock_executions = [MagicMock()] - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.side_effect = [app, account, tenant_join] - with patch( - "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" - ) as mock_repo: + with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo: mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) @@ -423,7 +421,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {} - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -432,14 +430,14 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {"app_id": "app-1"} - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() # Ensure init_app is mocked mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.return_value = None - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -449,8 +447,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "tenant-1" trace_info.metadata = {"user_id": "user-1"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")): - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")): + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() mock_db.engine = MagicMock() @@ -476,8 +474,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "t" trace_info.metadata = {"user_id": "u"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: user_id = tencent_data_trace._get_user_id(trace_info) assert user_id == "unknown" mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID") @@ -519,7 +517,7 @@ class TestTencentDataTrace: node.process_data = None node.outputs = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_llm_metrics(node) # Should not crash @@ -557,7 +555,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.metadata = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_llm_metrics(trace_info) # Should not crash @@ -609,7 +607,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_workflow_trace_duration(trace_info) def test_record_message_trace_duration(self, tencent_data_trace): @@ -631,7 +629,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.start_time = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_trace_duration(trace_info) def test_del(self, tencent_data_trace): @@ -641,6 +639,6 @@ class TestTencentDataTrace: def test_del_exception(self, tencent_data_trace): tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.__del__() mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py similarity index 88% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py index ef28d18e20..63c6d680d7 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py @@ -8,10 +8,9 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest +from dify_trace_tencent.utils import TencentTraceUtils from opentelemetry.trace import Link, TraceFlags -from core.ops.tencent_trace.utils import TencentTraceUtils - def test_convert_to_trace_id_with_valid_uuid() -> None: uuid_str = "12345678-1234-5678-1234-567812345678" @@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None: def test_convert_to_trace_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int uuid4_mock.assert_called_once() @@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None: def test_convert_to_span_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: span_id = TencentTraceUtils.convert_to_span_id(None, "workflow") assert isinstance(span_id, int) uuid4_mock.assert_called_once() @@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None: def test_generate_span_id_skips_invalid_span_id() -> None: with patch( - "core.ops.tencent_trace.utils.random.getrandbits", + "dify_trace_tencent.utils.random.getrandbits", side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42], ) as bits_mock: assert TencentTraceUtils.generate_span_id() == 42 @@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None: fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) expected = int(fixed.timestamp() * 1e9) - with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock: + with patch("dify_trace_tencent.utils.datetime") as datetime_mock: datetime_mock.now.return_value = fixed assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected datetime_mock.now.assert_called_once() @@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i @pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None]) def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None: fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type] assert link.context.trace_id == fallback_uuid.int uuid4_mock.assert_called_once() diff --git a/api/providers/trace/trace-weave/pyproject.toml b/api/providers/trace/trace-weave/pyproject.toml new file mode 100644 index 0000000000..ba449f2a93 --- /dev/null +++ b/api/providers/trace/trace-weave/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-weave" +version = "0.0.1" +dependencies = [ + "weave>=0.52.36", +] +description = "Dify ops tracing provider (Weave)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/config.py b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py new file mode 100644 index 0000000000..5942bd57fe --- /dev/null +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py @@ -0,0 +1,29 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class WeaveConfig(BaseTracingConfig): + """ + Model class for Weave tracing config. + """ + + api_key: str + entity: str | None = None + project: str + endpoint: str = "https://trace.wandb.ai" + host: str | None = None + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # Weave only allows HTTPS for endpoint + return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + if v is not None and v.strip() != "": + return validate_url(v, v, allowed_schemes=("https", "http")) + return v diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py similarity index 100% rename from api/core/ops/weave_trace/entities/weave_trace_entity.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed b/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py similarity index 99% rename from api/core/ops/weave_trace/weave_trace.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py index 8d9ba4694d..4292cbf0f1 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py @@ -6,7 +6,6 @@ from typing import Any, cast import wandb import weave -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -18,7 +17,6 @@ from weave.trace_server.trace_server_interface import ( ) from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -30,9 +28,11 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..eeb1fe1d87 --- /dev/null +++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py @@ -0,0 +1,61 @@ +import pytest +from dify_trace_weave.config import WeaveConfig +from pydantic import ValidationError + + +class TestWeaveConfig: + """Test cases for WeaveConfig""" + + def test_valid_config(self): + """Test valid Weave configuration""" + config = WeaveConfig( + api_key="test_key", + entity="test_entity", + project="test_project", + endpoint="https://custom.wandb.ai", + host="https://custom.host.com", + ) + assert config.api_key == "test_key" + assert config.entity == "test_entity" + assert config.project == "test_project" + assert config.endpoint == "https://custom.wandb.ai" + assert config.host == "https://custom.host.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = WeaveConfig(api_key="key", project="project") + assert config.entity is None + assert config.endpoint == "https://trace.wandb.ai" + assert config.host is None + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + WeaveConfig() + + with pytest.raises(ValidationError): + WeaveConfig(api_key="key") + + with pytest.raises(ValidationError): + WeaveConfig(project="project") + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") + + def test_host_validation_optional(self): + """Test host validation is optional but validates when provided""" + config = WeaveConfig(api_key="key", project="project", host=None) + assert config.host is None + + config = WeaveConfig(api_key="key", project="project", host="") + assert config.host == "" + + config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") + assert config.host == "https://valid.host.com" + + def test_host_validation_invalid_scheme(self): + """Test host validation rejects invalid schemes when provided""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py similarity index 97% rename from api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py rename to api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py index 531c7de05f..6028d0c550 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" +"""Comprehensive tests for dify_trace_weave.weave_trace module.""" from __future__ import annotations @@ -7,9 +7,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel +from dify_trace_weave.weave_trace import WeaveDataTrace from weave.trace_server.trace_server_interface import TraceStatus -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -20,8 +22,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.ops.weave_trace.weave_trace import WeaveDataTrace from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -191,14 +191,14 @@ def _make_node(**overrides): @pytest.fixture def mock_wandb(): - with patch("core.ops.weave_trace.weave_trace.wandb") as mock: + with patch("dify_trace_weave.weave_trace.wandb") as mock: mock.login.return_value = True yield mock @pytest.fixture def mock_weave(): - with patch("core.ops.weave_trace.weave_trace.weave") as mock: + with patch("dify_trace_weave.weave_trace.weave") as mock: client = MagicMock() client.entity = "my-entity" client.project = "my-project" @@ -307,7 +307,7 @@ class TestGetProjectUrl: monkeypatch.setattr(trace_instance, "entity", None) monkeypatch.setattr(trace_instance, "project_name", None) # Force an error by making string formatting fail - with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger: + with patch("dify_trace_weave.weave_trace.logger") as mock_logger: # Simulate exception via property original_entity = trace_instance.entity trace_instance.entity = None @@ -594,9 +594,9 @@ class TestWorkflowTrace: mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) return repo def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch): @@ -703,8 +703,8 @@ class TestWorkflowTrace: def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch): """Raises ValueError when app_id is missing from metadata.""" - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) trace_info = _make_workflow_trace_info( message_id=None, @@ -802,7 +802,7 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.get", + "dify_trace_weave.weave_trace.db.session.get", lambda model, pk: None, ) @@ -824,7 +824,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -846,7 +846,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = end_user - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -866,7 +866,7 @@ class TestMessageTrace: """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -884,7 +884,7 @@ class TestMessageTrace: """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -899,7 +899,7 @@ class TestMessageTrace: """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() diff --git a/api/pyproject.toml b/api/pyproject.toml index af8eb864b0..a21d98f7bf 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -32,9 +32,6 @@ dependencies = [ "flask-restx>=1.3.2,<2.0.0", "google-cloud-aiplatform>=1.147.0,<2.0.0", "httpx[socks]>=0.28.1,<1.0.0", - "langfuse>=4.2.0,<5.0.0", - "langsmith>=0.7.31,<1.0.0", - "mlflow-skinny>=3.11.1,<4.0.0", "opentelemetry-distro>=0.62b0,<1.0.0", "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0", "opentelemetry-instrumentation-flask>=0.62b0,<1.0.0", @@ -44,15 +41,12 @@ dependencies = [ "opentelemetry-propagator-b3>=1.41.0,<2.0.0", "readabilipy>=0.3.0,<1.0.0", "resend>=2.27.0,<3.0.0", - "weave>=0.52.36,<1.0.0", # Emerging: newer and fast-moving, use compatible pins - "arize-phoenix-otel~=0.15.0", "fastopenapi[flask]~=0.7.0", "graphon~=0.1.2", "httpx-sse~=0.4.0", "json-repair~=0.59.2", - "opik~=1.11.2", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -61,8 +55,8 @@ dependencies = [ packages = [] [tool.uv.workspace] -members = ["providers/vdb/*"] -exclude = ["providers/vdb/__pycache__"] +members = ["providers/vdb/*", "providers/trace/*"] +exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"] [tool.uv.sources] dify-vdb-alibabacloud-mysql = { workspace = true } @@ -95,9 +89,17 @@ dify-vdb-upstash = { workspace = true } dify-vdb-vastbase = { workspace = true } dify-vdb-vikingdb = { workspace = true } dify-vdb-weaviate = { workspace = true } +dify-trace-aliyun = { workspace = true } +dify-trace-arize-phoenix = { workspace = true } +dify-trace-langfuse = { workspace = true } +dify-trace-langsmith = { workspace = true } +dify-trace-mlflow = { workspace = true } +dify-trace-opik = { workspace = true } +dify-trace-tencent = { workspace = true } +dify-trace-weave = { workspace = true } [tool.uv] -default-groups = ["storage", "tools", "vdb-all"] +default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false override-dependencies = [ "pyarrow>=18.0.0", @@ -272,6 +274,25 @@ vdb-weaviate = ["dify-vdb-weaviate"] # Optional client used by some tests / integrations (not a vector backend plugin) vdb-xinference = ["xinference-client>=2.4.0"] +trace-all = [ + "dify-trace-aliyun", + "dify-trace-arize-phoenix", + "dify-trace-langfuse", + "dify-trace-langsmith", + "dify-trace-mlflow", + "dify-trace-opik", + "dify-trace-tencent", + "dify-trace-weave", +] +trace-aliyun = ["dify-trace-aliyun"] +trace-arize-phoenix = ["dify-trace-arize-phoenix"] +trace-langfuse = ["dify-trace-langfuse"] +trace-langsmith = ["dify-trace-langsmith"] +trace-mlflow = ["dify-trace-mlflow"] +trace-opik = ["dify-trace-opik"] +trace-tencent = ["dify-trace-tencent"] +trace-weave = ["dify-trace-weave"] + [tool.pyrefly] project-includes = ["."] project-excludes = [".venv", "migrations/"] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index 3e5ece1fcf..fbbca24558 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -34,12 +34,12 @@ core/external_data_tool/api/api.py core/llm_generator/llm_generator.py core/llm_generator/output_parser/structured_output.py core/mcp/mcp_client.py -core/ops/aliyun_trace/data_exporter/traceclient.py -core/ops/arize_phoenix_trace/arize_phoenix_trace.py -core/ops/mlflow_trace/mlflow_trace.py +providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py +providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py core/ops/ops_trace_manager.py -core/ops/tencent_trace/client.py -core/ops/tencent_trace/utils.py +providers/trace/trace-tencent/src/dify_trace_tencent/client.py +providers/trace/trace-tencent/src/dify_trace_tencent/utils.py core/plugin/backwards_invocation/base.py core/plugin/backwards_invocation/model.py core/prompt/utils/extract_thread_messages.py diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index c4582e891d..ac0e2a3a53 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -5,7 +5,8 @@ ".venv", "migrations/", "core/rag", - "providers/", + "providers/vdb/", + "providers/trace/*/tests", ], "typeCheckingMode": "strict", "allowedUntypedLibraries": [ diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 100589804c..72b38e7906 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol, TypedDict -from graphon.entities.pause_reason import PauseReason -from graphon.enums import WorkflowType from sqlalchemy.orm import Session from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index d5c6a203b1..44735eb769 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index b760696c5e..474b200fc5 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,15 +28,15 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index feba5f7eb6..67f8795d3f 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,9 +7,6 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any -from graphon.nodes.human_input.entities import FormDefinition -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -21,6 +18,9 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 74b800606d..78806927bc 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -10,12 +10,6 @@ from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel @@ -35,6 +29,12 @@ from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerSc from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 0842e9d3e7..6e9d6b1c73 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,11 +5,10 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. """ -from graphon.graph_engine.manager import GraphEngineManager - from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1c7027efb4..60948e652b 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,12 +5,12 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context -from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index ea12e40420..dcc93b4b0f 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,7 +6,6 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, sessionmaker @@ -14,6 +13,7 @@ from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index f5085af59b..ee8a1c4edd 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,7 +3,6 @@ import logging from collections.abc import Callable, Sequence from typing import Any -from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -13,6 +12,7 @@ from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from factories import variable_factory +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 95a8951951..287d513f48 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ -from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 364c4a86a0..416bc8cef9 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,7 +3,6 @@ import time from collections.abc import Mapping from typing import Any -from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy import delete, func, select, update from sqlalchemy.orm import Session, sessionmaker @@ -18,6 +17,7 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index a944ef6acd..6679c08ebd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,15 +1,6 @@ from collections.abc import Sequence from enum import StrEnum -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -24,6 +15,15 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from models.provider import ProviderType diff --git a/api/services/file_service.py b/api/services/file_service.py index 79a935de4b..52da2a7951 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -8,7 +8,6 @@ from tempfile import NamedTemporaryFile from typing import Literal from zipfile import ZIP_DEFLATED, ZipFile -from graphon.file import helpers as file_helpers from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound @@ -24,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 77576fa4c0..68ef67dec1 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,7 +4,6 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol -from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker @@ -18,6 +17,7 @@ from core.workflow.human_input_compat import ( ) from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 02a6620fc7..76598d31ac 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,12 +3,6 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any -from graphon.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -17,6 +11,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index b652e049ce..c269346f5f 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -2,12 +2,6 @@ import json import logging from typing import Any, TypedDict -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -18,6 +12,12 @@ from core.model_manager import LBModelManager from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index aa7456dcd3..8c9a81af87 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -50,7 +50,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod @@ -60,5 +60,5 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param template_id: Template ID :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 0ffbef8365..9d446f6d4b 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TypedDict import yaml from sqlalchemy import select @@ -10,6 +10,30 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class CustomizedTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + position: int + chunk_structure: str + + +class CustomizedTemplatesResultDict(TypedDict): + pipeline_templates: list[CustomizedTemplateItemDict] + + +class CustomizedTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + created_by: str + + class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval recommended app from database @@ -17,12 +41,10 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): def get_pipeline_templates(self, language: str) -> dict[str, Any]: _, current_tenant_id = current_account_with_tenant() - result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) - return result + return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.CUSTOMIZED @@ -40,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) ).all() - recommended_pipelines_results = [] + recommended_pipelines_results: list[CustomizedTemplateItemDict] = [] for pipeline_customized_template in pipeline_customized_templates: - recommended_pipeline_result = { + recommended_pipeline_result: CustomizedTemplateItemDict = { "id": pipeline_customized_template.id, "name": pipeline_customized_template.name, "description": pipeline_customized_template.description, diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 073eed221c..2964537c35 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TypedDict import yaml from sqlalchemy import select @@ -9,18 +9,41 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class PipelineTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + copyright: str + privacy_policy: str + position: int + chunk_structure: str + + +class PipelineTemplatesResultDict(TypedDict): + pipeline_templates: list[PipelineTemplateItemDict] + + +class PipelineTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + + class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval pipeline template from database """ def get_pipeline_templates(self, language: str) -> dict[str, Any]: - result = self.fetch_pipeline_templates_from_db(language) - return result + return self.fetch_pipeline_templates_from_db(language) def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.DATABASE @@ -39,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ).all() ) - recommended_pipelines_results = [] + recommended_pipelines_results: list[PipelineTemplateItemDict] = [] for pipeline_built_in_template in pipeline_built_in_templates: - recommended_pipeline_result = { + recommended_pipeline_result: PipelineTemplateItemDict = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, "description": pipeline_built_in_template.description, diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index d5ef745bec..9565ac46cc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -17,21 +17,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result: dict[str, Any] | None try: - result = self.fetch_pipeline_template_detail_from_dify_official(template_id) + return self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) def get_pipeline_templates(self, language: str) -> dict[str, Any]: try: - result = self.fetch_pipeline_templates_from_dify_official(language) + return self.fetch_pipeline_templates_from_dify_official(language) except Exception as e: logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) def get_type(self) -> str: return PipelineTemplateType.REMOTE diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 605689226a..968600d1bc 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -9,15 +9,6 @@ from typing import Any, cast from uuid import uuid4 from flask_login import current_user -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus -from graphon.errors import WorkflowNodeRunFailedError -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.runtime import VariablePool -from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -53,6 +44,15 @@ from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.dataset import ( # type: ignore diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 7dd86f1581..f315d053cb 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -13,12 +13,6 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from packaging import version from pydantic import BaseModel from sqlalchemy import select @@ -33,6 +27,12 @@ from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index ab60986bfe..21be411bea 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -27,13 +27,13 @@ from dataclasses import dataclass, field from typing import Any, TypedDict import click -from graphon.enums import WorkflowType from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from graphon.enums import WorkflowType from libs.archive_storage import ( ArchiveStorage, ArchiveStorageNotConfiguredError, diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 3bfa221528..5ff2c21749 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -2,9 +2,9 @@ import json import logging from typing import Any, TypedDict, cast -from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select +from sqlalchemy.orm import sessionmaker from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_runtime import ToolRuntime @@ -16,11 +16,13 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ) +from core.tools.errors import ApiToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -116,71 +118,85 @@ class ApiToolManageService: privacy_policy: str, custom_disclaimer: str, labels: list[str], - ): + ) -> dict[str, Any]: """ - create api tool provider + Create a new API tool provider. + + :param user_id: The ID of the user creating the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :param icon: The icon configuration for the provider. + :param credentials: The credentials for the provider. + :param schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :param privacy_policy: The privacy policy URL or text. + :param custom_disclaimer: Custom disclaimer text. + :param labels: A list of labels for the provider. + :return: A dictionary indicating the result status. """ + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # Create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is not None: - raise ValueError(f"provider {provider_name} already exists") + if provider is not None: + raise ValueError(f"provider {provider_name} already exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if len(tool_bundles) > 100: - raise ValueError("the number of apis should be less than 100") + if len(tool_bundles) > 100: + raise ValueError("the number of apis should be less than 100") - # create db provider - db_provider = ApiToolProvider( - tenant_id=tenant_id, - user_id=user_id, - name=provider_name, - icon=json.dumps(icon), - schema=schema, - description=extra_info.get("description", ""), - schema_type_str=schema_type, - tools_str=json.dumps(jsonable_encoder(tool_bundles)), - credentials_str="{}", - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - ) + # create API tool provider + api_tool_provider = ApiToolProvider( + tenant_id=tenant_id, + user_id=user_id, + name=provider_name, + icon=json.dumps(icon), + schema=schema, + description=extra_info.get("description", ""), + schema_type_str=schema_type, + tools_str=json.dumps(jsonable_encoder(tool_bundles)), + credentials_str="{}", + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + ) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # create provider entity + provider_controller = ApiToolProviderController.from_db(api_tool_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - # encrypt credentials - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) - db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) + # encrypt credentials + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) + api_tool_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) - db.session.add(db_provider) - db.session.commit() + _session.add(api_tool_provider) - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels, _session) return {"result": "success"} @@ -212,16 +228,25 @@ class ApiToolManageService: @staticmethod def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]: """ - list api tool provider tools + List tools provided by a specific API tool provider. + + :param user_id: The ID of the user requesting the list. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :return: A list of ToolApiEntity objects. """ - provider: ApiToolProvider | None = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + + # create new session with automatic transaction management + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) if provider is None: raise ValueError(f"you have not added provider {provider_name}") @@ -251,103 +276,133 @@ class ApiToolManageService: privacy_policy: str | None, custom_disclaimer: str, labels: list[str], - ): + ) -> dict[str, Any]: """ - update api tool provider + Update an existing API tool provider. + + :param user_id: The ID of the user updating the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The new name of the API tool provider. + :param original_provider: The original name of the API tool provider. + :param icon: The icon configuration for the provider. + :param credentials: The credentials for the provider. + :param _schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :param privacy_policy: The privacy policy URL or text. + :param custom_disclaimer: Custom disclaimer text. + :param labels: A list of labels for the provider. + :return: A dictionary indicating the result status. """ + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + if provider is None: + raise ApiToolProviderNotFoundError(provider_name=original_provider, tenant_id=tenant_id) - # update db provider - provider.name = provider_name - provider.icon = json.dumps(icon) - provider.schema = schema - provider.description = extra_info.get("description", "") - provider.schema_type_str = schema_type - provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) - provider.privacy_policy = privacy_policy - provider.custom_disclaimer = custom_disclaimer + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + # update db provider + provider.name = provider_name + provider.icon = json.dumps(icon) + provider.schema = schema + provider.description = extra_info.get("description", "") + provider.schema_type_str = schema_type + provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) + provider.privacy_policy = privacy_policy + provider.custom_disclaimer = custom_disclaimer - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # create provider entity - provider_controller = ApiToolProviderController.from_db(provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # get original credentials if exists - encrypter, cache = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) + # create provider entity + provider_controller = ApiToolProviderController.from_db(provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - original_credentials = encrypter.decrypt(provider.credentials) - masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] + # get original credentials if exists + encrypter, cache = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) - credentials = dict(encrypter.encrypt(credentials)) - provider.credentials_str = json.dumps(credentials) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - db.session.add(provider) - db.session.commit() + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + + credentials = dict(encrypter.encrypt(credentials)) + provider.credentials_str = json.dumps(credentials) + + _session.add(provider) + + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels, _session) # delete cache cache.delete() - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) - return {"result": "success"} @staticmethod def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + Delete an API tool provider. + + :param user_id: The ID of the user performing the deletion operation. + :param tenant_id: The ID of the workspace/tenant where the provider belongs. + :param provider_name: The unique name of the API tool provider to be deleted. + :raises ValueError: If the specified provider does not exist in the tenant. + :return: A dictionary indicating the result status. """ - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"you have not added provider {provider_name}") + if provider is None: + raise ValueError(f"you have not added provider {provider_name}") - db.session.delete(provider) - db.session.commit() + _session.delete(provider) return {"result": "success"} @staticmethod - def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str) -> dict[str, Any]: """ - get api tool provider + Get API tool provider details. + + :param user_id: The ID of the user requesting the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider: The name of the API tool provider. + :return: A dictionary containing the provider details. """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) @@ -360,10 +415,20 @@ class ApiToolManageService: parameters: dict[str, Any], schema_type: ApiProviderSchemaType, schema: str, - ): + ) -> dict[str, Any]: """ - test api tool before adding api tool provider + Test an API tool before adding the API tool provider. + + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :param tool_name: The name of the specific tool to test. + :param credentials: The credentials for the provider. + :param parameters: The parameters to pass to the tool. + :param schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :return: A dictionary containing the result or error message. """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type {schema_type}") @@ -377,18 +442,21 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # create new session with automatic transaction management to get the provider + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if not db_provider: + if provider is None: # create a fake db provider - db_provider = ApiToolProvider( + provider = ApiToolProvider( tenant_id="", user_id="", name="", @@ -407,12 +475,12 @@ class ApiToolManageService: auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) # decrypt credentials - if db_provider.id: + if provider.id: encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, controller=provider_controller, @@ -443,14 +511,21 @@ class ApiToolManageService: @staticmethod def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ - list api tools + List all API tools for a specific tenant. + + :param tenant_id: The ID of the workspace/tenant. + :return: A list of ToolProviderApiEntity objects. """ # get all api providers - db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + # create new session with automatic transaction management + providers: list[ApiToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + ) result: list[ToolProviderApiEntity] = [] - - for provider in db_providers: + for provider in providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 5a5d13b96d..911331e357 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,7 +5,6 @@ from collections.abc import Mapping from typing import Any from flask import Request, Response -from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -21,6 +20,7 @@ from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_ from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 4d58a9cf12..c96050ce13 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, overload +from configs import dify_config from graphon.file import File from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.variables.segments import ( @@ -21,8 +22,6 @@ from graphon.variables.segments import ( ) from graphon.variables.utils import dumps_with_segments -from configs import dify_config - _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 9827c8dfbc..58193d75a9 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,6 +1,5 @@ import logging -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, select from core.model_manager import ModelInstance, ModelManager @@ -13,6 +12,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1582bcd46c..5dedb9e372 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,11 +1,6 @@ import json from typing import Any, TypedDict -from graphon.file import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.input_entities import VariableEntity from sqlalchemy import select from core.app.app_config.entities import ( @@ -24,6 +19,11 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.app_event import app_was_created from extensions.ext_database import db +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index fae5dea3cb..8afb565955 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -7,19 +7,6 @@ from datetime import datetime from enum import StrEnum from typing import Any, ClassVar, NotRequired, TypedDict -from graphon.enums import NodeType -from graphon.file import File -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.variable_assigner.common.helpers import get_updated_variables -from graphon.variable_loader import VariableLoader -from graphon.variables import Segment, StringSegment, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from graphon.variables.types import SegmentType -from graphon.variables.utils import dumps_with_segments from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -40,6 +27,19 @@ from core.workflow.variable_prefixes import ( from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 601e9261fc..5fca444723 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,10 +9,6 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker @@ -26,6 +22,10 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index aab41efe50..a68f69a8d3 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,31 +5,6 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast -from graphon.entities import WorkflowNodeExecution -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import ( - ErrorStrategy, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission -from graphon.nodes.human_input.enums import HumanInputFormKind -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.start.entities import StartNodeData -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import load_into_variable_pool -from graphon.variables import VariableBase -from graphon.variables.input_entities import VariableEntityType -from graphon.variables.variables import Variable from sqlalchemy import and_, exists, or_, select from sqlalchemy.orm import Session, sessionmaker @@ -64,6 +39,31 @@ from events.app_event import app_draft_workflow_was_synced, app_published_workfl from extensions.ext_database import db from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings +from graphon.entities import WorkflowNodeExecution +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission +from graphon.nodes.human_input.enums import HumanInputFormKind +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import load_into_variable_pool +from graphon.variables import VariableBase +from graphon.variables.input_entities import VariableEntityType +from graphon.variables.variables import Variable from libs.datetime_utils import naive_utc_now from libs.helper import escape_like_pattern from models import Account diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 8f2f5f261e..c22e7e9918 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,7 +7,6 @@ from typing import Annotated, Any from celery import shared_task from flask import current_app, json -from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -23,6 +22,7 @@ from core.app.entities.app_invoke_entities import ( from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 4db551c73c..beb23d8354 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -8,7 +8,6 @@ from typing import Any import click import pandas as pd from celery import shared_task -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func, select from core.db.session_factory import session_factory @@ -16,6 +15,7 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index ca73b4d374..fd743205a1 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,8 +2,6 @@ import logging from datetime import timedelta from celery import shared_task -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker @@ -11,6 +9,8 @@ from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index a316eec7b9..f8ae3f4b6e 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,7 +6,6 @@ from typing import Any import click from celery import shared_task -from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -15,6 +14,7 @@ from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index b9f382eccf..b0cbc54db3 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,7 +12,6 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -29,6 +28,7 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType +from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index b4f975f4da..5ca04fd7c2 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -10,11 +10,11 @@ import logging from typing import Any from celery import shared_task -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 128cdd72e1..0d5475a56d 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -10,13 +10,13 @@ import logging from typing import Any from celery import shared_task +from sqlalchemy import select + +from core.db.session_factory import session_factory from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter -from sqlalchemy import select - -from core.db.session_factory import session_factory from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 91245e879e..a876b0c4aa 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,9 +1,8 @@ from collections.abc import Generator -from graphon.node_events import StreamCompletedEvent - from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage +from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3fdea10976..b5318aaa2b 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,8 +1,7 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index ce04a158a8..c4146d5ccd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,6 +4,9 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + # import monkeypatch from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import ( @@ -23,9 +26,6 @@ from graphon.model_runtime.entities.model_entities import ( ) from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - class MockModelClass(PluginModelClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index c7bb90f019..e130644338 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,10 +3,6 @@ import unittest import uuid import pytest -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType -from graphon.variables.variables import StringVariable from sqlalchemy import delete, func, select from sqlalchemy.orm import Session @@ -15,6 +11,10 @@ from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 3dfedd811d..4f444598b1 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import patch import pytest -from graphon.variables.segments import StringSegment from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -209,7 +209,6 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant from graphon.variables.types import SegmentType - from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -453,7 +452,6 @@ class TestDeleteDraftVariablesSessionCommit: def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" from graphon.variables.types import SegmentType - from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index c0143faa85..a9a2617bae 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -1,12 +1,11 @@ from unittest.mock import MagicMock -from graphon.model_runtime.entities.model_entities import ModelType - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 4f41396c22..e3476c292b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,17 @@ import time import uuid import pytest + +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import NodeRunResult from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.limits import CodeNodeLimits from graphon.runtime import GraphRuntimeState, VariablePool - -from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.code_executor",) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index b1f937e738..aa6cf1e021 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,11 +3,6 @@ import uuid from urllib.parse import urlencode import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -16,6 +11,11 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.http",) @@ -192,6 +192,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from core.workflow.system_variables import build_system_variables from graphon.enums import BuiltinNodeTypes from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -202,8 +203,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool - from core.workflow.system_variables import build_system_variables - # Create variable pool variable_pool = VariablePool( system_variables=build_system_variables(user_id="test", files=[]), diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index f0f3fcead1..fa5d63cfbf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,6 +4,11 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.llm_generator.output_parser.structured_output import _parse_structured_output +from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent from graphon.nodes.llm.file_saver import LLMFileSaver @@ -12,12 +17,6 @@ from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.nodes.protocols import HttpClientProtocol from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.model_manager import ModelInstance -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index fe512c2585..52886855b8 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,17 +3,16 @@ import time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 2d728569be..9e3e1a47e3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,15 +1,14 @@ import time import uuid +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index ea95959a82..5a22f81a69 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -4,11 +4,11 @@ import json import uuid from flask.testing import FlaskClient -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index b4b65abdb6..c342e8994b 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,13 +22,6 @@ import uuid from time import time import pytest -from graphon.entities.pause_reason import SchedulingPause -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events import GraphRunPausedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -40,6 +33,13 @@ from core.app.layers.pause_state_persist_layer import ( ) from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 3b1570a9a8..14d5740072 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,7 +4,6 @@ from __future__ import annotations from uuid import uuid4 -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from sqlalchemy import Engine, select from sqlalchemy.orm import Session @@ -18,6 +17,7 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from models.account import ( Account, AccountStatus, diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 3ecf621095..da4f8847d6 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -4,6 +4,17 @@ from datetime import timedelta from unittest.mock import MagicMock import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowType from graphon.graph import Graph from graphon.graph_engine import GraphEngine @@ -16,17 +27,6 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool -from sqlalchemy import delete, select -from sqlalchemy.orm import Session - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from models import Account from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index cc72dc1cf3..2e207ddc67 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -4,13 +4,13 @@ from unittest.mock import patch from uuid import uuid4 import pytest -from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index a68b3a08c7..641399c7f9 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 64c93ac07c..aebe87839c 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -8,15 +8,15 @@ from unittest.mock import Mock from uuid import uuid4 import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_storage import storage from graphon.entities import WorkflowExecution from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from sqlalchemy import Engine, delete, select -from sqlalchemy.orm import Session, sessionmaker - -from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 4f3c0e4200..00a2f9a59f 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -842,7 +842,6 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from graphon.file import FileTransferMethod, FileType - from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 6c15587058..77ce28b999 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -9,7 +9,6 @@ from uuid import uuid4 import pytest import yaml from faker import Faker -from graphon.enums import BuiltinNodeTypes from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, @@ -17,6 +16,7 @@ from core.trigger.constants import ( TRIGGER_WEBHOOK_NODE_TYPE, ) from extensions.ext_redis import redis_client +from graphon.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import AppModelConfig, IconType from services import app_dsl_service diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index f9bfa570cb..0de3c64c4f 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 2974e00897..ac0483a45d 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -3,10 +3,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType +from graphon.model_runtime.entities.model_entities import ModelType from models.account import ( Account, AccountStatus, diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index c8f04e9215..fe426ae516 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. from datetime import UTC, datetime, timedelta from uuid import uuid4 -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import select +from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index c46b8fba0b..18c5320d0a 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,8 +3,6 @@ import uuid from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from core.workflow.human_input_compat import ( EmailDeliveryConfig, @@ -12,6 +10,8 @@ from core.workflow.human_input_compat import ( EmailRecipients, ExternalRecipient, ) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index ba926bf675..8955a3b5f2 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -405,11 +405,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.provider_entities import ProviderEntity - from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( ProviderEntity( @@ -644,9 +643,8 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response - from graphon.model_runtime.entities.common_entities import I18nObject - from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity + from graphon.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 749c6fff5b..1e57b5603d 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -8,9 +8,9 @@ from unittest.mock import patch import pytest from faker import Faker -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session +from graphon.enums import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 0c281c8c33..86cf2327c7 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker -from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 7c43bf676b..4dab895135 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 2fb62e0fc0..fa3ac12cf0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -11,7 +11,8 @@ from unittest.mock import Mock, patch import pytest from faker import Faker -from sqlalchemy import func, select +from sqlalchemy import ColumnElement, func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -21,6 +22,14 @@ from tasks.clean_notion_document_task import clean_notion_document_task from tests.test_containers_integration_tests.helpers import generate_valid_password +def _count_documents(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(Document).where(condition)) or 0 + + +def _count_segments(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(DocumentSegment).where(condition)) or 0 + + class TestCleanNotionDocumentTask: """Integration tests for clean_notion_document_task using testcontainers.""" @@ -146,29 +155,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(Document).where(Document.id.in_(document_ids)) - ) - == 3 - ) - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) - ) - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(document_ids)) == 3 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 6 # Execute cleanup task clean_notion_document_task(document_ids, dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 0 # Verify index processor was called mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value @@ -328,12 +322,7 @@ class TestCleanNotionDocumentTask: # The task properly handles various index types and document configurations. # Verify segments are deleted - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Reset mock for next iteration mock_index_processor_factory.reset_mock() @@ -416,12 +405,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies that segments without index_node_ids # are properly deleted from the database. @@ -507,18 +491,8 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) - ) - == 5 - ) - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) - ) - == 10 - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == 5 + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 10 # Clean up only first 3 documents documents_to_clean = [doc.id for doc in documents[:3]] @@ -528,29 +502,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(documents_to_clean, dataset.id) # Verify only specified documents' segments are deleted - assert ( - db_session_with_containers.scalar( - select(func.count()) - .select_from(DocumentSegment) - .where(DocumentSegment.document_id.in_(documents_to_clean)) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(documents_to_clean)) == 0 # Verify remaining documents and segments are intact remaining_docs = [doc.id for doc in documents[3:]] - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs)) - ) - == 2 - ) - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs)) - ) - == 4 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 4 # Note: This test successfully verifies partial document cleanup operations. # The database operations work correctly, isolating only the specified documents. @@ -634,23 +591,13 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all segments exist before cleanup - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) - ) - == 4 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 4 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies database operations. # IndexProcessor verification would require more sophisticated mocking. @@ -820,16 +767,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == num_documents assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) - ) - == num_documents - ) - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) - ) + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == num_documents * num_segments_per_doc ) @@ -838,12 +778,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies bulk document cleanup operations. # The database efficiently handles large-scale deletions. @@ -950,29 +885,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([target_document.id], target_dataset.id) # Verify only documents' segments from target dataset are deleted - assert ( - db_session_with_containers.scalar( - select(func.count()) - .select_from(DocumentSegment) - .where(DocumentSegment.document_id == target_document.id) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == target_document.id) == 0 # Verify documents from other datasets remain intact remaining_docs = [doc.id for doc in all_documents[1:]] - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs)) - ) - == 2 - ) - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs)) - ) - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 6 # Note: This test successfully verifies multi-tenant isolation. # Only documents from the target dataset are affected, maintaining tenant separation. @@ -1067,13 +985,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.scalar( - select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) - ) == len(document_statuses) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == len(document_statuses) assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) - ) + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == len(document_statuses) * 2 ) @@ -1082,12 +996,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies cleanup of documents in various states. # All documents are deleted regardless of their indexing status. @@ -1185,29 +1094,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(Document).where(Document.id == document.id) - ) - == 1 - ) - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) - ) - == 3 - ) + assert _count_documents(db_session_with_containers, Document.id == document.id) == 1 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 3 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.scalar( - select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) - ) - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies cleanup of documents with rich metadata. # The task properly handles complex document structures and metadata fields. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 1b4dcf28ea..328bdbf055 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,9 +3,6 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import delete from configs import dify_config @@ -21,6 +18,9 @@ from core.workflow.human_input_compat import ( MemberRecipient, ) from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index b5bef145d5..b43b622870 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,12 +2,12 @@ import uuid from unittest.mock import ANY, call, patch import pytest -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 6e98c0855a..b00d827e37 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,12 +24,12 @@ from dataclasses import dataclass from datetime import timedelta import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 7c4553d4a0..9c20118e27 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -10,7 +10,6 @@ from typing import Any import pytest from flask import Flask, Response from flask.testing import FlaskClient -from graphon.enums import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session @@ -25,6 +24,7 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index e11102acb1..c4a8148446 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -6,14 +6,14 @@ from unittest.mock import Mock import pytest from flask import Flask -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index 740da1f1df..b19a1740eb 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch import pytest from flask_restx import marshal -from graphon.variables.types import SegmentType from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -16,6 +15,7 @@ from controllers.console.app.workflow_draft_variable import ( ) from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index 9c42ee9529..b2f949c6e2 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -11,9 +11,10 @@ from unittest.mock import MagicMock import pytest from flask import Flask from flask.views import MethodView +from werkzeug.exceptions import Forbidden + from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from werkzeug.exceptions import Forbidden if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index a26fea8fbd..c16ebad739 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,7 +13,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -30,6 +29,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 57681d8f5b..3364c07e62 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,7 +16,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -35,6 +34,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 97fdf1a011..14c35a9ed5 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -20,7 +20,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.variables.types import SegmentType from werkzeug.exceptions import BadRequest, NotFound import services @@ -38,6 +37,7 @@ from controllers.service_api.app.conversation import ( ConversationVariableUpdatePayload, ) from controllers.service_api.app.error import NotChatAppError +from graphon.variables.types import SegmentType from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService from services.errors.conversation import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 74a3c75839..da09ec13ce 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -20,7 +20,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.enums import WorkflowExecutionStatus from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -37,6 +36,7 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from graphon.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index 4b8e3a738c..eda270258d 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,8 +1,7 @@ from types import SimpleNamespace -from graphon.enums import WorkflowExecutionStatus - from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField +from graphon.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 8bde9c1f97..11b53dd0f9 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,8 +1,7 @@ +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager - def test_convert_with_vision(): config = { diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 1fb0dc6cf1..45d4b0e321 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 -from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from factories import variable_factory +from graphon.variables import SegmentType from models import ConversationVariable, Workflow MINIMAL_GRAPH = { diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index f255d2c7df..b3ea1a464f 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index 4a94a2b4f1..201923e0e4 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,11 +1,11 @@ from types import SimpleNamespace import pytest -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.workflow.system_variables import build_system_variables from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.runtime import GraphRuntimeState, VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 328cd12f12..3ab63aed25 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,10 +1,9 @@ from collections.abc import Mapping, Sequence +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.variables.segments import ArrayFileSegment, FileSegment -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter - class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method""" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index bc11bf4174..1bef6f69cd 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,13 +1,12 @@ from datetime import UTC, datetime from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index c9e146ff12..936ac37e55 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,11 +1,10 @@ from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 0fde7565d2..b3c0eb74fa 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,8 +10,6 @@ from typing import Any from unittest.mock import Mock import pytest -from graphon.entities import WorkflowStartReason -from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -27,6 +25,8 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 6167be3bbd..b0f8b423e1 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -476,9 +476,8 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): - from graphon.enums import BuiltinNodeTypes - from core.app.entities.app_invoke_entities import InvokeFrom + from graphon.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index aa789d9ff3..10fb2271f4 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 9e30faecf2..620a153204 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 8a717e1dcc..a3ab379b66 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,11 +3,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -16,6 +11,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 29df903aa8..1f6e7e12ef 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,15 +2,14 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState - from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index a78c1b428f..ba55e8f695 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,6 +1,9 @@ from collections.abc import Sequence from unittest.mock import Mock +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.command_channels import CommandChannel from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent @@ -8,10 +11,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import ReadOnlyGraphRuntimeState from graphon.variables import StringVariable from graphon.variables.segments import Segment, StringSegment - -from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035e64325b..539944d683 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,16 @@ from time import time from unittest.mock import Mock import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import ( + PauseStatePersistenceLayer, + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) +from core.workflow.system_variables import SystemVariableKey from graphon.entities.pause_reason import SchedulingPause from graphon.graph_engine.entities.commands import GraphEngineCommand from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -15,16 +25,6 @@ from graphon.graph_events import ( ) from graphon.runtime import ReadOnlyVariablePool from graphon.variables.segments import Segment - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import ( - PauseStatePersistenceLayer, - WorkflowResumptionContext, - _AdvancedChatAppGenerateEntityWrapper, - _WorkflowGenerateEntityWrapper, -) -from core.workflow.system_variables import SystemVariableKey from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 4aaa10a81a..1c1bf391d3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -28,6 +26,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index d338cadb77..81315d2508 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,15 +2,15 @@ import types from collections.abc import Generator import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.workflow.file_reference import parse_file_reference +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index fe2c226843..a28143026f 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) ] ) - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="encrypted-old-key" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): - with patch( - "core.entities.provider_configuration.encrypter.encrypt_token", - side_effect=lambda tenant_id, value: f"enc::{value}", - ): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, - credential_id="credential-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + ) assert validated["openai_api_key"] == "enc::restored-key" assert validated["region"] == "us" @@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) -def test_validate_provider_credentials_opens_session_when_not_passed() -> None: +def test_validate_provider_credentials_without_credential_id() -> None: configuration = _build_provider_configuration() - mock_session = Mock() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.Session") as mock_session_cls: - with patch("core.entities.provider_configuration.db") as mock_db: - mock_db.engine = Mock() - mock_session_cls.return_value.__enter__.return_value = mock_session - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory - ): - validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} - mock_session_cls.assert_called_once() def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: @@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non def test_validate_provider_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-key"} @@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback( def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( encrypted_config='{"openai_api_key":"enc"}' ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) - assert validated == {"openai_api_key": "enc-new"} - - session = Mock() - mock_factory = Mock() - mock_factory.model_credentials_validate.return_value = {"region": "us"} - with _patched_session(session): + with _patched_session(mock_session): with patch( "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory ): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"region": "us"}, - ) + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) + assert validated == {"openai_api_key": "enc-new"} + + mock_factory2 = Mock() + mock_factory2.model_credentials_validate.return_value = {"region": "us"} + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) assert validated == {"region": "us"} @@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = None + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} @@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 9a815fb94d..57456085c3 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,7 +3,6 @@ from unittest.mock import Mock, patch import jsonschema import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -19,6 +18,7 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 2cbff54c42..69650c85cc 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -1,16 +1,11 @@ -import pytest -from pydantic import ValidationError +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_opik.config import OpikConfig +from dify_trace_weave.config import WeaveConfig -from core.ops.entities.config_entity import ( - AliyunConfig, - ArizeConfig, - LangfuseConfig, - LangSmithConfig, - OpikConfig, - PhoenixConfig, - TracingProviderEnum, - WeaveConfig, -) +from core.ops.entities.config_entity import TracingProviderEnum class TestTracingProviderEnum: @@ -27,349 +22,8 @@ class TestTracingProviderEnum: assert TracingProviderEnum.ALIYUN == "aliyun" -class TestArizeConfig: - """Test cases for ArizeConfig""" - - def test_valid_config(self): - """Test valid Arize configuration""" - config = ArizeConfig( - api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" - ) - assert config.api_key == "test_key" - assert config.space_id == "test_space" - assert config.project == "test_project" - assert config.endpoint == "https://custom.arize.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = ArizeConfig() - assert config.api_key is None - assert config.space_id is None - assert config.project is None - assert config.endpoint == "https://otlp.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = ArizeConfig(project="") - assert config.project == "default" - - def test_project_validation_none(self): - """Test project validation with None value""" - config = ArizeConfig(project=None) - assert config.project == "default" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = ArizeConfig(endpoint="") - assert config.endpoint == "https://otlp.arize.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") - assert config.endpoint == "https://custom.arize.com" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="ftp://invalid.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="invalid.com") - - -class TestPhoenixConfig: - """Test cases for PhoenixConfig""" - - def test_valid_config(self): - """Test valid Phoenix configuration""" - config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.phoenix.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = PhoenixConfig() - assert config.api_key is None - assert config.project is None - assert config.endpoint == "https://app.phoenix.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = PhoenixConfig(project="") - assert config.project == "default" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation with path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") - assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" - - def test_endpoint_validation_without_path(self): - """Test endpoint validation without path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") - assert config.endpoint == "https://app.phoenix.arize.com" - - -class TestLangfuseConfig: - """Test cases for LangfuseConfig""" - - def test_valid_config(self): - """Test valid Langfuse configuration""" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == "https://custom.langfuse.com" - - def test_valid_config_with_path(self): - host = "https://custom.langfuse.com/api/v1" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == host - - def test_default_values(self): - """Test default values are set correctly""" - config = LangfuseConfig(public_key="public", secret_key="secret") - assert config.host == "https://api.langfuse.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangfuseConfig() - - with pytest.raises(ValidationError): - LangfuseConfig(public_key="public") - - with pytest.raises(ValidationError): - LangfuseConfig(secret_key="secret") - - def test_host_validation_empty(self): - """Test host validation with empty value""" - config = LangfuseConfig(public_key="public", secret_key="secret", host="") - assert config.host == "https://api.langfuse.com" - - -class TestLangSmithConfig: - """Test cases for LangSmithConfig""" - - def test_valid_config(self): - """Test valid LangSmith configuration""" - config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.smith.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = LangSmithConfig(api_key="key", project="project") - assert config.endpoint == "https://api.smith.langchain.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangSmithConfig() - - with pytest.raises(ValidationError): - LangSmithConfig(api_key="key") - - with pytest.raises(ValidationError): - LangSmithConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") - - -class TestOpikConfig: - """Test cases for OpikConfig""" - - def test_valid_config(self): - """Test valid Opik configuration""" - config = OpikConfig( - api_key="test_key", - project="test_project", - workspace="test_workspace", - url="https://custom.comet.com/opik/api/", - ) - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.workspace == "test_workspace" - assert config.url == "https://custom.comet.com/opik/api/" - - def test_default_values(self): - """Test default values are set correctly""" - config = OpikConfig() - assert config.api_key is None - assert config.project is None - assert config.workspace is None - assert config.url == "https://www.comet.com/opik/api/" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = OpikConfig(project="") - assert config.project == "Default Project" - - def test_url_validation_empty(self): - """Test URL validation with empty value""" - config = OpikConfig(url="") - assert config.url == "https://www.comet.com/opik/api/" - - def test_url_validation_missing_suffix(self): - """Test URL validation requires /api/ suffix""" - with pytest.raises(ValidationError, match="URL should end with /api/"): - OpikConfig(url="https://custom.comet.com/opik/") - - def test_url_validation_invalid_scheme(self): - """Test URL validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - OpikConfig(url="ftp://custom.comet.com/opik/api/") - - -class TestWeaveConfig: - """Test cases for WeaveConfig""" - - def test_valid_config(self): - """Test valid Weave configuration""" - config = WeaveConfig( - api_key="test_key", - entity="test_entity", - project="test_project", - endpoint="https://custom.wandb.ai", - host="https://custom.host.com", - ) - assert config.api_key == "test_key" - assert config.entity == "test_entity" - assert config.project == "test_project" - assert config.endpoint == "https://custom.wandb.ai" - assert config.host == "https://custom.host.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = WeaveConfig(api_key="key", project="project") - assert config.entity is None - assert config.endpoint == "https://trace.wandb.ai" - assert config.host is None - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - WeaveConfig() - - with pytest.raises(ValidationError): - WeaveConfig(api_key="key") - - with pytest.raises(ValidationError): - WeaveConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") - - def test_host_validation_optional(self): - """Test host validation is optional but validates when provided""" - config = WeaveConfig(api_key="key", project="project", host=None) - assert config.host is None - - config = WeaveConfig(api_key="key", project="project", host="") - assert config.host == "" - - config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") - assert config.host == "https://valid.host.com" - - def test_host_validation_invalid_scheme(self): - """Test host validation rejects invalid schemes when provided""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") - - -class TestAliyunConfig: - """Test cases for AliyunConfig""" - - def test_valid_config(self): - """Test valid Aliyun configuration""" - config = AliyunConfig( - app_name="test_app", - license_key="test_license_key", - endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", - ) - assert config.app_name == "test_app" - assert config.license_key == "test_license_key" - assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - assert config.app_name == "dify_app" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - AliyunConfig() - - with pytest.raises(ValidationError): - AliyunConfig(license_key="test_license") - - with pytest.raises(ValidationError): - AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_app_name_validation_empty(self): - """Test app_name validation with empty value""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" - ) - assert config.app_name == "dify_app" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = AliyunConfig(license_key="test_license", endpoint="") - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation preserves path for Aliyun endpoints""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - ) - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_license_key_required(self): - """Test that license_key is required and cannot be empty""" - with pytest.raises(ValidationError): - AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_valid_endpoint_format_examples(self): - """Test valid endpoint format examples from comments""" - valid_endpoints = [ - # cms2.0 public endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", - # cms2.0 intranet endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", - # xtrace public endpoint - "http://tracing-cn-heyuan.arms.aliyuncs.com", - # xtrace intranet endpoint - "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", - ] - - for endpoint in valid_endpoints: - config = AliyunConfig(license_key="test_license", endpoint=endpoint) - assert config.endpoint == endpoint - - class TestConfigIntegration: - """Integration tests for configuration classes""" + """Cross-provider configuration sanity checks""" def test_all_configs_can_be_instantiated(self): """Test that all config classes can be instantiated with valid data""" @@ -388,7 +42,6 @@ class TestConfigIntegration: def test_url_normalization_consistency(self): """Test that URL normalization works consistently across configs""" - # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index a3b1e5f6b0..704b82adc0 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,14 +17,6 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -45,6 +37,14 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index 90730dff5a..d49b6e4b71 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,12 +1,12 @@ from collections.abc import Generator import pytest -from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector +from graphon.file import File, FileTransferMethod, FileType class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 2b280dd674..395d392127 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,6 +2,13 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -11,13 +18,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 4a54649b28..803afa54d7 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,19 +1,18 @@ from unittest.mock import MagicMock -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index a4b3960b0a..5d865d934c 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,3 +1,5 @@ +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -7,9 +9,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil - def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 408cf14a51..4b8175b0b4 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,6 +49,10 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from graphon.model_runtime.errors.invoke import ( @@ -56,10 +60,6 @@ from graphon.model_runtime.errors.invoke import ( InvokeConnectionError, InvokeRateLimitError, ) -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 641c5d9ba0..7c4defc180 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,7 +53,6 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -64,6 +63,7 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document +from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e229d5fc1a..3d3322094e 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,10 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowType from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 7dbf78d0f0..05b4f3a053 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from graphon.enums import BuiltinNodeTypes - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 0fc82dda53..8be1ac318c 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,11 +7,6 @@ from datetime import datetime from types import SimpleNamespace import pytest -from graphon.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -26,6 +21,11 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 84fe522388..abdbc72085 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 27729e7f06..5af1376a0a 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index ac65d0c02b..f17927f16b 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,7 +1,6 @@ import json from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig - from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index f5efb78b61..afea9144c0 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock, patch import pytest import redis -from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 331166fe63..b19a21d7f4 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,15 +1,6 @@ from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -21,6 +12,15 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index 0e3a7e623a..43f3fbd5c9 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,9 +1,9 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index c20edd7400..72a73dd936 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -11,7 +11,6 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -25,6 +24,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 7406b88270..72052c8c05 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -2,6 +2,11 @@ import dataclasses import orjson import pytest +from pydantic import BaseModel + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup @@ -42,11 +47,6 @@ from graphon.variables.variables import ( StringVariable, Variable, ) -from pydantic import BaseModel - -from core.helper import encrypter -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool def _build_variable_pool( diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index 37ecd2890b..d4e862220a 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,4 +1,5 @@ import pytest + from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import StringSegment from graphon.variables.types import ArrayValidation, SegmentType diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 09254e17a3..94e788edb2 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import Any import pytest + from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 75b01bf42e..dae5e1ce98 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,4 +1,6 @@ import pytest +from pydantic import ValidationError + from graphon.variables import ( ArrayFileVariable, ArrayVariable, @@ -10,7 +12,6 @@ from graphon.variables import ( StringVariable, ) from graphon.variables.variables import VariableBase -from pydantic import ValidationError def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 41627f5e0b..025d79b25d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,12 +5,13 @@ Shared fixtures for ObservabilityLayer tests. from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider +from graphon.enums import BuiltinNodeTypes + @pytest.fixture def memory_span_exporter(): @@ -61,9 +62,8 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" - from graphon.nodes.tool.entities import ToolNodeData - from core.tools.entities.tool_entities import ToolProviderType + from graphon.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 9cf72763ee..919f15efd0 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -13,10 +13,10 @@ Test coverage: from unittest.mock import patch import pytest -from graphon.enums import BuiltinNodeTypes from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 88989db856..76b2984a4b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,12 +7,11 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any +from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node -from core.workflow.node_factory import DifyNodeFactory - from .test_mock_nodes import ( MockAgentNode, MockCodeNode, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 8b7fbd1b30..971b9b2bbf 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,6 +10,10 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock +from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -27,11 +31,6 @@ from graphon.nodes.template_transform import TemplateTransformNode from graphon.nodes.tool import ToolNode from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from core.model_manager import ModelInstance -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode - if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 8311a1e847..55a329eba9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,6 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -23,14 +30,6 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index b11f957677..7d23b63049 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,6 +19,11 @@ from functools import lru_cache from pathlib import Path from typing import Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.tools.utils.yaml_utils import _load_yaml_file +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from graphon.entities import GraphInitParams from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -39,12 +44,6 @@ from graphon.variables import ( StringVariable, ) -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool - from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 7195471eb6..9c0ad25b58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,15 +2,14 @@ import time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.answer.answer_node import AnswerNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 343bcd3919..ec4cef1955 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,10 +1,10 @@ import pytest + +from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node -from core.workflow.node_factory import get_node_type_classes_mapping - # Ensures that all production node classes are imported and registered. _ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index b9371a34f4..ef0df55995 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,6 +1,7 @@ import types from collections.abc import Mapping +from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node @@ -13,8 +14,6 @@ from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) -from core.workflow.node_factory import get_node_type_classes_mapping - def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index d155124c50..ce0c9b79c6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,3 +1,4 @@ +from configs import dify_config from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.entities import CodeLanguage, CodeNodeData from graphon.nodes.code.exc import ( @@ -8,8 +9,6 @@ from graphon.nodes.code.exc import ( from graphon.nodes.code.limits import CodeNodeLimits from graphon.variables.types import SegmentType -from configs import dify_config - CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index fb03ae9998..9cceadde49 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,8 +1,7 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index a5026b40cf..be7cc073db 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,4 +1,8 @@ import pytest + +from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables from graphon.file.file_manager import file_manager from graphon.nodes.http_request import ( BodyData, @@ -12,10 +16,6 @@ from graphon.nodes.http_request.exc import AuthorizationConfigError from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool -from configs import dify_config -from core.helper.ssrf_proxy import ssrf_proxy -from core.workflow.system_variables import default_system_variables - HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 4705b3f76e..a3cadc0681 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,17 +3,17 @@ from typing import Any import httpx import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d16e1233ac..1d6a4da7c4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,6 +1,5 @@ -from graphon.runtime import VariablePool - from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients +from graphon.runtime import VariablePool def test_render_body_template_replaces_variable_values(): diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 52802c7ce1..bc98028d5b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,6 +1,9 @@ import datetime from types import SimpleNamespace +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -11,10 +14,6 @@ from graphon.graph_events import ( from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import default_system_variables from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index ab64be59ad..45e8ae7d20 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -3,10 +3,6 @@ import uuid from unittest.mock import Mock import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -21,6 +17,10 @@ from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index fdf1706765..eca34f05be 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,14 @@ from unittest.mock import MagicMock import pytest + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.nodes.list_operator.node import ListOperatorNode from graphon.runtime import GraphRuntimeState from graphon.variables import ArrayNumberSegment, ArrayStringSegment -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY - class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 7841bf05ad..b1f81b6c48 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,6 +5,19 @@ from collections.abc import Sequence from unittest import mock import pytest + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.common_entities import I18nObject @@ -67,19 +80,6 @@ from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment - -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import ( - DifyCredentialsProvider, - DifyModelFactory, - build_dify_model_access, - fetch_model_config, -) -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.system_variables import default_system_variables from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index 1c362a0a03..8f8ec49f14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Any import pytest + +from factories.variable_factory import build_segment_with_type from graphon.model_runtime.entities import LLMMode from graphon.nodes.llm import ModelConfig, VisionConfig from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData @@ -18,8 +20,6 @@ from graphon.nodes.parameter_extractor.exc import ( from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from graphon.variables.types import SegmentType -from factories.variable_factory import build_segment_with_type - @dataclass class ValidTestCase: diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index e11ebf6eb8..0522dd9d14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -1,16 +1,16 @@ from collections.abc import Mapping import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.base_node_data import BaseNodeData from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_runtime import resolve_dify_run_context -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 555ff0c945..87ec2d5bce 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,6 +4,8 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod @@ -19,8 +21,6 @@ from graphon.nodes.document_extractor.node import ( from graphon.variables import ArrayFileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 1b14f0ab13..782750e02e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,6 +3,11 @@ import uuid from unittest.mock import MagicMock, Mock import pytest + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph @@ -11,11 +16,6 @@ from graphon.nodes.if_else.if_else_node import IfElseNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition from graphon.variables import ArrayFileSegment - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index d28c3e01e5..b217e4e8e7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,6 +1,8 @@ from unittest.mock import MagicMock import pytest + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.nodes.list_operator.entities import ( @@ -16,8 +18,6 @@ from graphon.nodes.list_operator.exc import InvalidKeyError from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from graphon.variables import ArrayFileSegment -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom - @pytest.fixture def list_operator_node(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 833c303052..543f9878de 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,16 +2,16 @@ import json import time import pytest +from pydantic import ValidationError as PydanticValidationError + +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState from graphon.variables import build_segment, segment_to_variable from graphon.variables.input_entities import VariableEntity, VariableEntityType from graphon.variables.variables import Variable -from pydantic import ValidationError as PydanticValidationError - -from core.workflow.system_variables import build_system_variables -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index f1132af02b..617554ee17 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,5 +1,4 @@ import pytest -from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -7,6 +6,7 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) +from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index e7b2b2914a..dddd6eb00c 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -2,6 +2,15 @@ import uuid from collections import defaultdict import pytest + +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from factories.variable_factory import build_segment, segment_to_variable from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables import FileSegment, StringSegment @@ -27,15 +36,6 @@ from graphon.variables.variables import ( Variable, ) -from core.workflow.system_variables import build_system_variables, system_variables_to_mapping -from core.workflow.variable_pool_initializer import add_variables_to_pool -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from factories.variable_factory import build_segment, segment_to_variable - @pytest.fixture def pool(): diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index d8361d06c4..041c5cc612 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,12 +1,6 @@ from types import SimpleNamespace import pytest -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.file import File, FileTransferMethod, FileType -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import VariablePool -from graphon.variables.variables import StringVariable from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage @@ -16,6 +10,12 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, ) from core.workflow.workflow_entry import WorkflowEntry +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import VariablePool +from graphon.variables.variables import StringVariable @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index 4b2f98aeff..80dc8927fa 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,11 +2,10 @@ from unittest.mock import MagicMock, patch -from graphon.graph_engine.command_channels import RedisChannel -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.workflow_entry import WorkflowEntry +from graphon.graph_engine.command_channels import RedisChannel +from graphon.runtime import GraphRuntimeState, VariablePool class TestWorkflowEntryRedisChannel: diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 4fe3f2cb28..efafc8aa79 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -2,15 +2,30 @@ import uuid from unittest.mock import MagicMock, patch import pytest -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from httpx import Response from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope from core.workflow.file_reference import build_file_reference, parse_file_reference, resolve_file_record_id from factories.file_factory.builders import build_from_mapping as _build_from_mapping +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from models import ToolFile, UploadFile + +def _make_session_ctx_mock(scalar_return=None): + """Return a mock usable as the ``session_factory.create_session()`` context manager. + + Patch ``factories.file_factory.builders.session_factory`` and set + ``mock_sf.create_session.return_value = `` to intercept DB calls + without requiring a live Flask app or database engine. + """ + session = MagicMock() + session.__enter__.return_value = session + session.__exit__.return_value = False + session.scalar.return_value = scalar_return + return session + + # Test Data TEST_TENANT_ID = "test_tenant_id" TEST_UPLOAD_FILE_ID = str(uuid.uuid4()) @@ -49,8 +64,11 @@ def mock_upload_file(): mock.source_url = TEST_REMOTE_URL mock.size = 1024 mock.key = "test_key" - with patch("factories.file_factory.builders.db.session.scalar", return_value=mock, autospec=True) as m: - yield m + session = _make_session_ctx_mock(scalar_return=mock) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session + # yield session.scalar so callers can inspect call_args and mutate return_value + yield session.scalar @pytest.fixture @@ -63,7 +81,9 @@ def mock_tool_file(): mock.mimetype = "application/pdf" mock.original_url = "http://example.com/tool.pdf" mock.size = 2048 - with patch("factories.file_factory.builders.db.session.scalar", return_value=mock, autospec=True): + session = _make_session_ctx_mock(scalar_return=mock) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session yield mock @@ -231,7 +251,9 @@ def test_build_from_remote_url_without_strict_validation(mock_http_head): def test_tool_file_not_found(): """Test ToolFile not found in database.""" - with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True): + session = _make_session_ctx_mock(scalar_return=None) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session mapping = tool_file_mapping() with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) @@ -239,7 +261,9 @@ def test_tool_file_not_found(): def test_local_file_not_found(): """Test UploadFile not found in database.""" - with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True): + session = _make_session_ctx_mock(scalar_return=None) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session mapping = local_file_mapping() with pytest.raises(ValueError, match="Invalid upload file"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) @@ -311,7 +335,9 @@ def test_tenant_mismatch(): mock_file.key = "test_key" # Mock the database query to return None (no file found for this tenant) - with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True): + session = _make_session_ctx_mock(scalar_return=None) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session mapping = local_file_mapping() with pytest.raises(ValueError, match="Invalid upload file"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) @@ -350,11 +376,13 @@ def test_build_from_mapping_scopes_tool_file_to_end_user(): invoke_from=InvokeFrom.WEB_APP, ) - with patch("factories.file_factory.builders.db.session.scalar", return_value=tool_file, autospec=True) as scalar: + session = _make_session_ctx_mock(scalar_return=tool_file) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session with bind_file_access_scope(scope): build_from_mapping(mapping=tool_file_mapping(), tenant_id=TEST_TENANT_ID) - stmt = scalar.call_args.args[0] + stmt = session.scalar.call_args.args[0] whereclause = str(stmt.whereclause) assert "tool_files.user_id" in whereclause diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index a06c42507d..c35e80a826 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,6 +4,11 @@ from typing import Any from uuid import uuid4 import pytest +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +from factories import variable_factory +from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type from graphon.file import File, FileTransferMethod, FileType from graphon.variables import ( ArrayNumberVariable, @@ -31,11 +36,6 @@ from graphon.variables.segments import ( StringSegment, ) from graphon.variables.types import SegmentType -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from factories import variable_factory -from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type def test_string_variable(): diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py index f1ce1a2c1c..fa2c02020b 100644 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -5,6 +5,7 @@ Unit tests for FormService. from datetime import timedelta import pytest + from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -13,7 +14,6 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) - from libs.datetime_utils import naive_utc_now from .support import ( diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py index 0babfbb315..866ee61b3e 100644 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -5,6 +5,7 @@ Unit tests for human input form models. from datetime import datetime, timedelta import pytest + from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -13,7 +14,6 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) - from libs.datetime_utils import naive_utc_now from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index 86163f1554..bb3a6db1a1 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,8 +1,7 @@ from uuid import uuid4 -from graphon.variables import SegmentType - from factories import variable_factory +from graphon.variables import SegmentType from models import ConversationVariable diff --git a/api/tests/unit_tests/models/test_model.py b/api/tests/unit_tests/models/test_model.py index 3f6d6bfbe3..a87dd7f15a 100644 --- a/api/tests/unit_tests/models/test_model.py +++ b/api/tests/unit_tests/models/test_model.py @@ -2,9 +2,9 @@ import importlib import types import pytest -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from core.workflow.file_reference import build_file_reference +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from models.model import Conversation, Message diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index e7c0479757..f7bdc97eb5 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -3,14 +3,13 @@ import json from unittest import mock from uuid import uuid4 -from graphon.file import File, FileTransferMethod, FileType -from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from graphon.variables.segments import IntegerSegment, Segment - from constants import HIDDEN_VALUE from core.helper import encrypter from core.workflow.file_reference import build_file_reference from factories.variable_factory import build_segment +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from graphon.variables.segments import IntegerSegment, Segment from models.workflow import ( Workflow, WorkflowDraftVariable, diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index 507e1c8c3a..eb9fef7587 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -13,12 +13,12 @@ from datetime import UTC, datetime from uuid import uuid4 import pytest + from graphon.enums import ( BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, ) - from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import ( Workflow, diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 6903c47a24..71df8c4e20 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -109,11 +109,11 @@ This test suite follows a comprehensive testing strategy that covers: from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.model_entities import ModelType from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.entities import PreProcessingRule, Rule, Segmentation from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 68f4c51afe..2c7f13b79f 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -6,26 +6,15 @@ Tests are organized by functionality and include edge cases, error handling, and both positive and negative test scenarios. """ -from datetime import timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch -import pytest from sqlalchemy import asc, desc from core.app.entities.app_invoke_entities import InvokeFrom from libs.datetime_utils import naive_utc_now -from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable -from models.enums import ConversationFromSource from models.model import App, Conversation, EndUser, Message from services.conversation_service import ConversationService -from services.errors.conversation import ( - ConversationNotExistsError, - ConversationVariableNotExistsError, - ConversationVariableTypeMismatchError, - LastConversationNotExistsError, -) -from services.errors.message import MessageNotExistsError class ConversationServiceTestDataFactory: @@ -338,330 +327,9 @@ class TestConversationServiceHelpers: assert condition is not None -class TestConversationServiceGetConversation: - """Test conversation retrieval operations.""" - - @patch("services.conversation_service.db.session") - def test_get_conversation_success_with_account(self, mock_db_session): - """ - Test successful conversation retrieval with account user. - - Should return conversation when found with proper filters. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_account_id=user.id, from_source=ConversationFromSource.CONSOLE - ) - - mock_db_session.scalar.return_value = conversation - - # Act - result = ConversationService.get_conversation(app_model, "conv-123", user) - - # Assert - assert result == conversation - - @patch("services.conversation_service.db.session") - def test_get_conversation_success_with_end_user(self, mock_db_session): - """ - Test successful conversation retrieval with end user. - - Should return conversation when found with proper filters for API user. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_end_user_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_end_user_id=user.id, from_source=ConversationFromSource.API - ) - - mock_db_session.scalar.return_value = conversation - - # Act - result = ConversationService.get_conversation(app_model, "conv-123", user) - - # Assert - assert result == conversation - - @patch("services.conversation_service.db.session") - def test_get_conversation_not_found_raises_error(self, mock_db_session): - """ - Test that get_conversation raises error when conversation not found. - - Should raise ConversationNotExistsError when no matching conversation found. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - - mock_db_session.scalar.return_value = None - - # Act & Assert - with pytest.raises(ConversationNotExistsError): - ConversationService.get_conversation(app_model, "conv-123", user) - - -class TestConversationServiceRename: - """Test conversation rename operations.""" - - @patch("services.conversation_service.db.session") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_rename_with_manual_name(self, mock_get_conversation, mock_db_session): - """ - Test renaming conversation with manual name. - - Should update conversation name and timestamp when auto_generate is False. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Act - result = ConversationService.rename( - app_model=app_model, - conversation_id="conv-123", - user=user, - name="New Name", - auto_generate=False, - ) - - # Assert - assert result == conversation - assert conversation.name == "New Name" - mock_db_session.commit.assert_called_once() - - -class TestConversationServiceAutoGenerateName: - """Test conversation auto-name generation operations.""" - - @patch("services.conversation_service.db.session") - @patch("services.conversation_service.LLMGenerator") - def test_auto_generate_name_success(self, mock_llm_generator, mock_db_session): - """ - Test successful auto-generation of conversation name. - - Should generate name using LLMGenerator and update conversation. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - message = ConversationServiceTestDataFactory.create_message_mock( - conversation_id=conversation.id, app_id=app_model.id - ) - - # Mock database query to return message - mock_db_session.scalar.return_value = message - - # Mock LLM generator - mock_llm_generator.generate_conversation_name.return_value = "Generated Name" - - # Act - result = ConversationService.auto_generate_name(app_model, conversation) - - # Assert - assert result == conversation - assert conversation.name == "Generated Name" - mock_llm_generator.generate_conversation_name.assert_called_once_with( - app_model.tenant_id, message.query, conversation.id, app_model.id - ) - mock_db_session.commit.assert_called_once() - - @patch("services.conversation_service.db.session") - def test_auto_generate_name_no_message_raises_error(self, mock_db_session): - """ - Test auto-generation fails when no message found. - - Should raise MessageNotExistsError when conversation has no messages. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - # Mock database query to return None - mock_db_session.scalar.return_value = None - - # Act & Assert - with pytest.raises(MessageNotExistsError): - ConversationService.auto_generate_name(app_model, conversation) - - @patch("services.conversation_service.db.session") - @patch("services.conversation_service.LLMGenerator") - def test_auto_generate_name_handles_llm_exception(self, mock_llm_generator, mock_db_session): - """ - Test auto-generation handles LLM generator exceptions gracefully. - - Should continue without name when LLMGenerator fails. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - message = ConversationServiceTestDataFactory.create_message_mock( - conversation_id=conversation.id, app_id=app_model.id - ) - - # Mock database query to return message - mock_db_session.scalar.return_value = message - - # Mock LLM generator to raise exception - mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error") - - # Act - result = ConversationService.auto_generate_name(app_model, conversation) - - # Assert - assert result == conversation - # Name should remain unchanged due to exception - mock_db_session.commit.assert_called_once() - - -class TestConversationServiceDelete: - """Test conversation deletion operations.""" - - @patch("services.conversation_service.delete_conversation_related_data") - @patch("services.conversation_service.db.session") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_delete_success(self, mock_get_conversation, mock_db_session, mock_delete_task): - """ - Test successful conversation deletion. - - Should delete conversation and schedule cleanup task. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock(name="Test App") - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Act - ConversationService.delete(app_model, "conv-123", user) - - # Assert - mock_db_session.delete.assert_called_once_with(conversation) - mock_db_session.commit.assert_called_once() - mock_delete_task.delay.assert_called_once_with(conversation.id) - - class TestConversationServiceConversationalVariable: """Test conversational variable operations.""" - @patch("services.conversation_service.session_factory") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_get_conversational_variable_success(self, mock_get_conversation, mock_session_factory): - """ - Test successful retrieval of conversational variables. - - Should return paginated list of variables for conversation. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Mock session and variables - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - variable1 = ConversationServiceTestDataFactory.create_conversation_variable_mock() - variable2 = ConversationServiceTestDataFactory.create_conversation_variable_mock(variable_id="var-456") - - mock_session.scalars.return_value.all.return_value = [variable1, variable2] - - # Act - result = ConversationService.get_conversational_variable( - app_model=app_model, - conversation_id="conv-123", - user=user, - limit=10, - last_id=None, - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - assert len(result.data) == 2 - assert result.limit == 10 - assert result.has_more is False - - @patch("services.conversation_service.session_factory") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_get_conversational_variable_with_last_id(self, mock_get_conversation, mock_session_factory): - """ - Test retrieval of variables with last_id pagination. - - Should filter variables created after last_id. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Mock session and variables - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( - created_at=naive_utc_now() - timedelta(hours=1) - ) - variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=naive_utc_now()) - - mock_session.scalar.return_value = last_variable - mock_session.scalars.return_value.all.return_value = [variable] - - # Act - result = ConversationService.get_conversational_variable( - app_model=app_model, - conversation_id="conv-123", - user=user, - limit=10, - last_id="var-123", - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - assert len(result.data) == 1 - assert result.limit == 10 - - @patch("services.conversation_service.session_factory") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_get_conversational_variable_last_id_not_found_raises_error( - self, mock_get_conversation, mock_session_factory - ): - """ - Test that invalid last_id raises ConversationVariableNotExistsError. - - Should raise error when last_id doesn't exist. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Mock session - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - mock_session.scalar.return_value = None - - # Act & Assert - with pytest.raises(ConversationVariableNotExistsError): - ConversationService.get_conversational_variable( - app_model=app_model, - conversation_id="conv-123", - user=user, - limit=10, - last_id="invalid-id", - ) - @patch("services.conversation_service.session_factory") @patch("services.conversation_service.ConversationService.get_conversation") @patch("services.conversation_service.dify_config") @@ -698,466 +366,3 @@ class TestConversationServiceConversationalVariable: # Assert - JSON filter should be applied assert mock_session.scalars.called - - @patch("services.conversation_service.session_factory") - @patch("services.conversation_service.ConversationService.get_conversation") - @patch("services.conversation_service.dify_config") - def test_get_conversational_variable_with_name_filter_postgresql( - self, mock_config, mock_get_conversation, mock_session_factory - ): - """ - Test variable filtering by name for PostgreSQL databases. - - Should apply JSON extraction filter for variable names. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - mock_config.DB_TYPE = "postgresql" - - # Mock session - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - mock_session.scalars.return_value.all.return_value = [] - - # Act - ConversationService.get_conversational_variable( - app_model=app_model, - conversation_id="conv-123", - user=user, - limit=10, - last_id=None, - variable_name="test_var", - ) - - # Assert - JSON filter should be applied - assert mock_session.scalars.called - - -class TestConversationServiceUpdateVariable: - """Test conversation variable update operations.""" - - @patch("services.conversation_service.variable_factory") - @patch("services.conversation_service.ConversationVariableUpdater") - @patch("services.conversation_service.session_factory") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_update_conversation_variable_success( - self, mock_get_conversation, mock_session_factory, mock_updater_class, mock_variable_factory - ): - """ - Test successful update of conversation variable. - - Should update variable value and return updated data. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Mock session and existing variable - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="string") - mock_session.scalar.return_value = existing_variable - - # Mock variable factory and updater - updated_variable = Mock() - updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": "new_value"} - mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable - - mock_updater = MagicMock() - mock_updater_class.return_value = mock_updater - - # Act - result = ConversationService.update_conversation_variable( - app_model=app_model, - conversation_id="conv-123", - variable_id="var-123", - user=user, - new_value="new_value", - ) - - # Assert - assert result["id"] == "var-123" - assert result["value"] == "new_value" - mock_updater.update.assert_called_once_with("conv-123", updated_variable) - mock_updater.flush.assert_called_once() - - @patch("services.conversation_service.session_factory") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_update_conversation_variable_not_found_raises_error(self, mock_get_conversation, mock_session_factory): - """ - Test update fails when variable doesn't exist. - - Should raise ConversationVariableNotExistsError. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Mock session - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - mock_session.scalar.return_value = None - - # Act & Assert - with pytest.raises(ConversationVariableNotExistsError): - ConversationService.update_conversation_variable( - app_model=app_model, - conversation_id="conv-123", - variable_id="invalid-id", - user=user, - new_value="new_value", - ) - - @patch("services.conversation_service.session_factory") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_update_conversation_variable_type_mismatch_raises_error(self, mock_get_conversation, mock_session_factory): - """ - Test update fails when value type doesn't match expected type. - - Should raise ConversationVariableTypeMismatchError. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Mock session and existing variable - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="number") - mock_session.scalar.return_value = existing_variable - - # Act & Assert - Try to set string value for number variable - with pytest.raises(ConversationVariableTypeMismatchError): - ConversationService.update_conversation_variable( - app_model=app_model, - conversation_id="conv-123", - variable_id="var-123", - user=user, - new_value="string_value", # Wrong type - ) - - @patch("services.conversation_service.session_factory") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_update_conversation_variable_integer_number_compatibility( - self, mock_get_conversation, mock_session_factory - ): - """ - Test that integer type accepts number values. - - Should allow number values for integer type variables. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - - # Mock session and existing variable - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="integer") - mock_session.scalar.return_value = existing_variable - - # Mock variable factory and updater - updated_variable = Mock() - updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": 42} - - with ( - patch("services.conversation_service.variable_factory") as mock_variable_factory, - patch("services.conversation_service.ConversationVariableUpdater") as mock_updater_class, - ): - mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable - mock_updater = MagicMock() - mock_updater_class.return_value = mock_updater - - # Act - result = ConversationService.update_conversation_variable( - app_model=app_model, - conversation_id="conv-123", - variable_id="var-123", - user=user, - new_value=42, # Number value for integer type - ) - - # Assert - assert result["value"] == 42 - mock_updater.update.assert_called_once() - - -class TestConversationServicePaginationAdvanced: - """Advanced pagination tests for ConversationService.""" - - @patch("services.conversation_service.session_factory") - def test_pagination_by_last_id_with_last_id_not_found(self, mock_session_factory): - """ - Test pagination with invalid last_id raises error. - - Should raise LastConversationNotExistsError when last_id doesn't exist. - """ - # Arrange - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - mock_session.scalar.return_value = None - - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - - # Act & Assert - with pytest.raises(LastConversationNotExistsError): - ConversationService.pagination_by_last_id( - session=mock_session, - app_model=app_model, - user=user, - last_id="invalid-id", - limit=20, - invoke_from=InvokeFrom.WEB_APP, - ) - - @patch("services.conversation_service.session_factory") - def test_pagination_by_last_id_with_exclude_ids(self, mock_session_factory): - """ - Test pagination with exclude_ids filter. - - Should exclude specified conversation IDs from results. - """ - # Arrange - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_session.scalars.return_value.all.return_value = [conversation] - mock_session.scalar.return_value = conversation - - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - - # Act - result = ConversationService.pagination_by_last_id( - session=mock_session, - app_model=app_model, - user=user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - exclude_ids=["excluded-123"], - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - assert len(result.data) == 1 - - @patch("services.conversation_service.session_factory") - def test_pagination_by_last_id_has_more_detection(self, mock_session_factory): - """ - Test pagination has_more detection logic. - - Should set has_more=True when there are more results beyond limit. - """ - # Arrange - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - # Return exactly limit items to trigger has_more check - conversations = [ - ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=f"conv-{i}") for i in range(20) - ] - mock_session.scalars.return_value.all.return_value = conversations - mock_session.scalar.return_value = conversations[-1] - - # Mock count query to return > 0 - mock_session.scalar.return_value = 5 # Additional items exist - - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - - # Act - result = ConversationService.pagination_by_last_id( - session=mock_session, - app_model=app_model, - user=user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - assert result.has_more is True - - @patch("services.conversation_service.session_factory") - def test_pagination_by_last_id_with_different_sort_by(self, mock_session_factory): - """ - Test pagination with different sort fields. - - Should handle various sort_by parameters correctly. - """ - # Arrange - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_session.scalars.return_value.all.return_value = [conversation] - mock_session.scalar.return_value = conversation - - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - - # Test different sort fields - sort_fields = ["created_at", "-updated_at", "name", "-status"] - - for sort_by in sort_fields: - # Act - result = ConversationService.pagination_by_last_id( - session=mock_session, - app_model=app_model, - user=user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - sort_by=sort_by, - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - - -class TestConversationServiceEdgeCases: - """Test edge cases and error scenarios.""" - - @patch("services.conversation_service.session_factory") - def test_pagination_with_end_user_api_source(self, mock_session_factory): - """ - Test pagination correctly handles EndUser with API source. - - Should use 'api' as from_source for EndUser instances. - """ - # Arrange - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_source=ConversationFromSource.API, from_end_user_id="user-123" - ) - mock_session.scalars.return_value.all.return_value = [conversation] - - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_end_user_mock() - - # Act - result = ConversationService.pagination_by_last_id( - session=mock_session, - app_model=app_model, - user=user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - - @patch("services.conversation_service.session_factory") - def test_pagination_with_account_console_source(self, mock_session_factory): - """ - Test pagination correctly handles Account with console source. - - Should use 'console' as from_source for Account instances. - """ - # Arrange - mock_session = MagicMock() - mock_session_factory.create_session.return_value.__enter__.return_value = mock_session - - conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_source=ConversationFromSource.CONSOLE, from_account_id="account-123" - ) - mock_session.scalars.return_value.all.return_value = [conversation] - - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - - # Act - result = ConversationService.pagination_by_last_id( - session=mock_session, - app_model=app_model, - user=user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - - def test_pagination_with_include_ids_filter(self): - """ - Test pagination with include_ids filter. - - Should only return conversations with IDs in include_ids list. - """ - # Arrange - mock_session = MagicMock() - mock_session.scalars.return_value.all.return_value = [] - - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - - # Act - result = ConversationService.pagination_by_last_id( - session=mock_session, - app_model=app_model, - user=user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - include_ids=["conv-123", "conv-456"], - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - # Verify that include_ids filter was applied - assert mock_session.scalars.called - - def test_pagination_with_empty_exclude_ids(self): - """ - Test pagination with empty exclude_ids list. - - Should handle empty exclude_ids gracefully. - """ - # Arrange - mock_session = MagicMock() - mock_session.scalars.return_value.all.return_value = [] - - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - - # Act - result = ConversationService.pagination_by_last_id( - session=mock_session, - app_model=app_model, - user=user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - exclude_ids=[], - ) - - # Assert - assert isinstance(result, InfiniteScrollPagination) - assert result.has_more is False diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 9be475d043..55af564821 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -3,18 +3,18 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest -from graphon.nodes.human_input.entities import ( - FormDefinition, - FormInput, - UserAction, -) -from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus import services.human_input_service as human_input_service_module from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) +from graphon.nodes.human_input.entities import ( + FormDefinition, + FormInput, + UserAction, +) +from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from services.human_input_service import ( diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 1bd979b9ec..97f3bd6f01 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -1,11 +1,11 @@ import types import pytest + +from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ConfigurateMethod - -from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from models.provider import ProviderType from services.model_provider_service import ModelProviderService diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 98ec6fb77c..4b864dd221 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -16,6 +16,7 @@ from typing import Any from uuid import uuid4 import pytest + from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segments import ( ArrayFileSegment, @@ -28,7 +29,6 @@ from graphon.variables.segments import ( ObjectSegment, StringSegment, ) - from services.variable_truncator import ( DummyVariableTruncator, MaxDepthExceededError, diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index bf645f9795..c3335e5723 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -567,7 +567,6 @@ class TestWebhookServiceUnit: from types import SimpleNamespace from typing import Any, cast -from graphon.variables.types import SegmentType from werkzeug.exceptions import RequestEntityTooLarge from core.workflow.nodes.trigger_webhook.entities import ( @@ -576,6 +575,7 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, WebhookParameter, ) +from graphon.variables.types import SegmentType from models.enums import AppTriggerStatus from models.model import App from models.trigger import WorkflowWebhookTrigger diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index a62c9f4555..239cc83518 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -13,10 +13,10 @@ from datetime import datetime from unittest.mock import MagicMock, create_autospec, patch import pytest -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index 8525672da8..60beec7964 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -4,12 +4,12 @@ import json from unittest.mock import Mock, patch import pytest -from graphon.file import File, FileTransferMethod, FileType -from graphon.variables.segments import ObjectSegment, StringSegment -from graphon.variables.types import SegmentType from sqlalchemy import Engine from core.workflow.file_reference import build_file_reference +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import ObjectSegment, StringSegment +from graphon.variables.types import SegmentType from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index e7e72793a3..f6bdb6a60e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -4,10 +4,6 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.file import File, FileTransferMethod, FileType -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from sqlalchemy import Engine from sqlalchemy.orm import Session @@ -17,6 +13,10 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) +from graphon.enums import BuiltinNodeTypes +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index 4146fd312b..d570dce107 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -6,13 +6,13 @@ from datetime import UTC, datetime from threading import Event import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index 98d057e41f..d7192994b2 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -3,9 +3,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from sqlalchemy.orm import sessionmaker from core.workflow.human_input_compat import ( @@ -15,6 +12,9 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index 7119217e94..591da56f49 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -5,8 +5,8 @@ from types import SimpleNamespace from typing import Any import pytest -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from tasks import human_input_timeout_tasks as task_module diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index 544e89fcee..689b973097 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -4,7 +4,6 @@ from typing import Any from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMUsage from core.mcp.types import ( AudioContent, @@ -19,6 +18,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool +from graphon.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict[str, Any] | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index ffa6833524..c166a946d9 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -2,6 +2,9 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -18,9 +21,6 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output - def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: """Create a mock LLMUsage with all required fields""" diff --git a/api/uv.lock b/api/uv.lock index db00ccf800..587e59c8a7 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -11,6 +11,14 @@ resolution-markers = [ [manifest] members = [ "dify-api", + "dify-trace-aliyun", + "dify-trace-arize-phoenix", + "dify-trace-langfuse", + "dify-trace-langsmith", + "dify-trace-mlflow", + "dify-trace-opik", + "dify-trace-tencent", + "dify-trace-weave", "dify-vdb-alibabacloud-mysql", "dify-vdb-analyticdb", "dify-vdb-baidu", @@ -380,14 +388,14 @@ wheels = [ [[package]] name = "authlib" -version = "1.6.9" +version = "1.6.11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/af/98/00d3dd826d46959ad8e32af2dbb2398868fd9fd0683c26e56d0789bd0e68/authlib-1.6.9.tar.gz", hash = "sha256:d8f2421e7e5980cc1ddb4e32d3f5fa659cfaf60d8eaf3281ebed192e4ab74f04", size = 165134, upload-time = "2026-03-02T07:44:01.998Z" } +sdist = { url = "https://files.pythonhosted.org/packages/28/10/b325d58ffe86815b399334a101e63bc6fa4e1953921cb23703b48a0a0220/authlib-1.6.11.tar.gz", hash = "sha256:64db35b9b01aeccb4715a6c9a6613a06f2bd7be2ab9d2eb89edd1dfc7580a38f", size = 165359, upload-time = "2026-04-16T07:22:50.279Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/53/23/b65f568ed0c22f1efacb744d2db1a33c8068f384b8c9b482b52ebdbc3ef6/authlib-1.6.9-py2.py3-none-any.whl", hash = "sha256:f08b4c14e08f0861dc18a32357b33fbcfd2ea86cfe3fe149484b4d764c4a0ac3", size = 244197, upload-time = "2026-03-02T07:44:00.307Z" }, + { url = "https://files.pythonhosted.org/packages/57/2f/55fca558f925a51db046e5b929deb317ddb05afed74b22d89f4eca578980/authlib-1.6.11-py2.py3-none-any.whl", hash = "sha256:c8687a9a26451c51a34a06fa17bb97cb15bba46a6a626755e2d7f50da8bff3e3", size = 244469, upload-time = "2026-04-16T07:22:48.413Z" }, ] [[package]] @@ -1372,7 +1380,6 @@ version = "1.13.3" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, - { name = "arize-phoenix-otel" }, { name = "azure-identity" }, { name = "bleach" }, { name = "boto3" }, @@ -1395,9 +1402,6 @@ dependencies = [ { name = "httpx", extra = ["socks"] }, { name = "httpx-sse" }, { name = "json-repair" }, - { name = "langfuse" }, - { name = "langsmith" }, - { name = "mlflow-skinny" }, { name = "opentelemetry-distro" }, { name = "opentelemetry-instrumentation-celery" }, { name = "opentelemetry-instrumentation-flask" }, @@ -1405,7 +1409,6 @@ dependencies = [ { name = "opentelemetry-instrumentation-redis" }, { name = "opentelemetry-instrumentation-sqlalchemy" }, { name = "opentelemetry-propagator-b3" }, - { name = "opik" }, { name = "psycogreen" }, { name = "psycopg2-binary" }, { name = "python-socketio" }, @@ -1414,7 +1417,6 @@ dependencies = [ { name = "resend" }, { name = "sendgrid" }, { name = "sseclient-py" }, - { name = "weave" }, ] [package.dev-dependencies] @@ -1501,6 +1503,40 @@ tools = [ { name = "cloudscraper" }, { name = "nltk" }, ] +trace-aliyun = [ + { name = "dify-trace-aliyun" }, +] +trace-all = [ + { name = "dify-trace-aliyun" }, + { name = "dify-trace-arize-phoenix" }, + { name = "dify-trace-langfuse" }, + { name = "dify-trace-langsmith" }, + { name = "dify-trace-mlflow" }, + { name = "dify-trace-opik" }, + { name = "dify-trace-tencent" }, + { name = "dify-trace-weave" }, +] +trace-arize-phoenix = [ + { name = "dify-trace-arize-phoenix" }, +] +trace-langfuse = [ + { name = "dify-trace-langfuse" }, +] +trace-langsmith = [ + { name = "dify-trace-langsmith" }, +] +trace-mlflow = [ + { name = "dify-trace-mlflow" }, +] +trace-opik = [ + { name = "dify-trace-opik" }, +] +trace-tencent = [ + { name = "dify-trace-tencent" }, +] +trace-weave = [ + { name = "dify-trace-weave" }, +] vdb-alibabacloud-mysql = [ { name = "dify-vdb-alibabacloud-mysql" }, ] @@ -1630,7 +1666,6 @@ vdb-xinference = [ [package.metadata] requires-dist = [ { name = "aliyun-log-python-sdk", specifier = ">=0.9.44,<1.0.0" }, - { name = "arize-phoenix-otel", specifier = "~=0.15.0" }, { name = "azure-identity", specifier = ">=1.25.3,<2.0.0" }, { name = "bleach", specifier = ">=6.3.0" }, { name = "boto3", specifier = ">=1.42.88" }, @@ -1653,9 +1688,6 @@ requires-dist = [ { name = "httpx", extras = ["socks"], specifier = ">=0.28.1,<1.0.0" }, { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "json-repair", specifier = "~=0.59.2" }, - { name = "langfuse", specifier = ">=4.2.0,<5.0.0" }, - { name = "langsmith", specifier = ">=0.7.31,<1.0.0" }, - { name = "mlflow-skinny", specifier = ">=3.11.1,<4.0.0" }, { name = "opentelemetry-distro", specifier = ">=0.62b0,<1.0.0" }, { name = "opentelemetry-instrumentation-celery", specifier = ">=0.62b0,<1.0.0" }, { name = "opentelemetry-instrumentation-flask", specifier = ">=0.62b0,<1.0.0" }, @@ -1663,7 +1695,6 @@ requires-dist = [ { name = "opentelemetry-instrumentation-redis", specifier = ">=0.62b0,<1.0.0" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = ">=0.62b0,<1.0.0" }, { name = "opentelemetry-propagator-b3", specifier = ">=1.41.0,<2.0.0" }, - { name = "opik", specifier = "~=1.11.2" }, { name = "psycogreen", specifier = ">=1.0.2" }, { name = "psycopg2-binary", specifier = ">=2.9.11" }, { name = "python-socketio", specifier = ">=5.13.0" }, @@ -1672,7 +1703,6 @@ requires-dist = [ { name = "resend", specifier = ">=2.27.0,<3.0.0" }, { name = "sendgrid", specifier = ">=6.12.5" }, { name = "sseclient-py", specifier = ">=1.8.0" }, - { name = "weave", specifier = ">=0.52.36,<1.0.0" }, ] [package.metadata.requires-dev] @@ -1759,6 +1789,24 @@ tools = [ { name = "cloudscraper", specifier = ">=1.2.71" }, { name = "nltk", specifier = ">=3.9.1" }, ] +trace-aliyun = [{ name = "dify-trace-aliyun", editable = "providers/trace/trace-aliyun" }] +trace-all = [ + { name = "dify-trace-aliyun", editable = "providers/trace/trace-aliyun" }, + { name = "dify-trace-arize-phoenix", editable = "providers/trace/trace-arize-phoenix" }, + { name = "dify-trace-langfuse", editable = "providers/trace/trace-langfuse" }, + { name = "dify-trace-langsmith", editable = "providers/trace/trace-langsmith" }, + { name = "dify-trace-mlflow", editable = "providers/trace/trace-mlflow" }, + { name = "dify-trace-opik", editable = "providers/trace/trace-opik" }, + { name = "dify-trace-tencent", editable = "providers/trace/trace-tencent" }, + { name = "dify-trace-weave", editable = "providers/trace/trace-weave" }, +] +trace-arize-phoenix = [{ name = "dify-trace-arize-phoenix", editable = "providers/trace/trace-arize-phoenix" }] +trace-langfuse = [{ name = "dify-trace-langfuse", editable = "providers/trace/trace-langfuse" }] +trace-langsmith = [{ name = "dify-trace-langsmith", editable = "providers/trace/trace-langsmith" }] +trace-mlflow = [{ name = "dify-trace-mlflow", editable = "providers/trace/trace-mlflow" }] +trace-opik = [{ name = "dify-trace-opik", editable = "providers/trace/trace-opik" }] +trace-tencent = [{ name = "dify-trace-tencent", editable = "providers/trace/trace-tencent" }] +trace-weave = [{ name = "dify-trace-weave", editable = "providers/trace/trace-weave" }] vdb-alibabacloud-mysql = [{ name = "dify-vdb-alibabacloud-mysql", editable = "providers/vdb/vdb-alibabacloud-mysql" }] vdb-all = [ { name = "dify-vdb-alibabacloud-mysql", editable = "providers/vdb/vdb-alibabacloud-mysql" }, @@ -1823,6 +1871,110 @@ vdb-vikingdb = [{ name = "dify-vdb-vikingdb", editable = "providers/vdb/vdb-viki vdb-weaviate = [{ name = "dify-vdb-weaviate", editable = "providers/vdb/vdb-weaviate" }] vdb-xinference = [{ name = "xinference-client", specifier = ">=2.4.0" }] +[[package]] +name = "dify-trace-aliyun" +version = "0.0.1" +source = { editable = "providers/trace/trace-aliyun" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-grpc" }, + { name = "opentelemetry-sdk" }, + { name = "opentelemetry-semantic-conventions" }, +] + +[package.metadata] +requires-dist = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-grpc" }, + { name = "opentelemetry-sdk" }, + { name = "opentelemetry-semantic-conventions" }, +] + +[[package]] +name = "dify-trace-arize-phoenix" +version = "0.0.1" +source = { editable = "providers/trace/trace-arize-phoenix" } +dependencies = [ + { name = "arize-phoenix-otel" }, +] + +[package.metadata] +requires-dist = [{ name = "arize-phoenix-otel", specifier = "~=0.15.0" }] + +[[package]] +name = "dify-trace-langfuse" +version = "0.0.1" +source = { editable = "providers/trace/trace-langfuse" } +dependencies = [ + { name = "langfuse" }, +] + +[package.metadata] +requires-dist = [{ name = "langfuse", specifier = ">=4.2.0,<5.0.0" }] + +[[package]] +name = "dify-trace-langsmith" +version = "0.0.1" +source = { editable = "providers/trace/trace-langsmith" } +dependencies = [ + { name = "langsmith" }, +] + +[package.metadata] +requires-dist = [{ name = "langsmith", specifier = "~=0.7.30" }] + +[[package]] +name = "dify-trace-mlflow" +version = "0.0.1" +source = { editable = "providers/trace/trace-mlflow" } +dependencies = [ + { name = "mlflow-skinny" }, +] + +[package.metadata] +requires-dist = [{ name = "mlflow-skinny", specifier = ">=3.11.1" }] + +[[package]] +name = "dify-trace-opik" +version = "0.0.1" +source = { editable = "providers/trace/trace-opik" } +dependencies = [ + { name = "opik" }, +] + +[package.metadata] +requires-dist = [{ name = "opik", specifier = "~=1.11.2" }] + +[[package]] +name = "dify-trace-tencent" +version = "0.0.1" +source = { editable = "providers/trace/trace-tencent" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-grpc" }, + { name = "opentelemetry-sdk" }, + { name = "opentelemetry-semantic-conventions" }, +] + +[package.metadata] +requires-dist = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-grpc" }, + { name = "opentelemetry-sdk" }, + { name = "opentelemetry-semantic-conventions" }, +] + +[[package]] +name = "dify-trace-weave" +version = "0.0.1" +source = { editable = "providers/trace/trace-weave" } +dependencies = [ + { name = "weave" }, +] + +[package.metadata] +requires-dist = [{ name = "weave", specifier = ">=0.52.36" }] + [[package]] name = "dify-vdb-alibabacloud-mysql" version = "0.0.1" @@ -3903,14 +4055,14 @@ wheels = [ [[package]] name = "mako" -version = "1.3.10" +version = "1.3.11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +sdist = { url = "https://files.pythonhosted.org/packages/59/8a/805404d0c0b9f3d7a326475ca008db57aea9c5c9f2e1e39ed0faa335571c/mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069", size = 399811, upload-time = "2026-04-14T20:19:51.493Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, + { url = "https://files.pythonhosted.org/packages/68/a5/19d7aaa7e433713ffe881df33705925a196afb9532efc8475d26593921a6/mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77", size = 78503, upload-time = "2026-04-14T20:19:53.233Z" }, ] [[package]] @@ -5501,11 +5653,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.10.1" +version = "6.10.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/66/79/f2730c42ec7891a75a2fcea2eb4f356872bcbc671b711418060424796612/pypdf-6.10.1.tar.gz", hash = "sha256:62e6ca7f65aaa28b3d192addb44f97296e4be1748f57ed0f4efb2d4915841880", size = 5315704, upload-time = "2026-04-14T12:55:20.996Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/3f/9f2167401c2e94833ca3b69535bad89e533b5de75fefe4197a2c224baec2/pypdf-6.10.2.tar.gz", hash = "sha256:7d09ce108eff6bf67465d461b6ef352dcb8d84f7a91befc02f904455c6eea11d", size = 5315679, upload-time = "2026-04-15T16:37:36.978Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/04/e3aa7f1f14dbc53429cae34666261eb935d99bd61d24756ab94d7e0309da/pypdf-6.10.1-py3-none-any.whl", hash = "sha256:6331940d3bfe75b7e6601d35db7adabab5fc1d716efaeb384e3c0c3957d033de", size = 335606, upload-time = "2026-04-14T12:55:18.941Z" }, + { url = "https://files.pythonhosted.org/packages/0c/d6/1d5c60cc17bbdf37c1552d9c03862fc6d32c5836732a0415b2d637edc2d0/pypdf-6.10.2-py3-none-any.whl", hash = "sha256:aa53be9826655b51c96741e5d7983ca224d898ac0a77896e64636810517624aa", size = 336308, upload-time = "2026-04-15T16:37:34.851Z" }, ] [[package]] diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index 962532de81..012c870c19 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -13,6 +13,7 @@ PYTEST_XDIST_ARGS="${PYTEST_XDIST_ARGS:--n auto}" pytest --timeout "${PYTEST_TIMEOUT}" ${PYTEST_XDIST_ARGS} \ api/tests/unit_tests \ api/providers/vdb/*/tests/unit_tests \ + api/providers/trace/*/tests/unit_tests \ --ignore=api/tests/unit_tests/controllers # Run controller tests sequentially to avoid import race conditions diff --git a/sdks/nodejs-client/tsconfig.json b/sdks/nodejs-client/tsconfig.json index 46055447be..1e55007ed0 100644 --- a/sdks/nodejs-client/tsconfig.json +++ b/sdks/nodejs-client/tsconfig.json @@ -1,18 +1,14 @@ { + "extends": "@dify/tsconfig/node.json", "compilerOptions": { - "target": "ES2022", - "module": "ESNext", - "moduleResolution": "Bundler", + "lib": ["ES2023", "DOM", "DOM.Iterable"], "rootDir": ".", "outDir": "dist", + "noEmit": false, "declaration": true, "declarationMap": true, "sourceMap": true, - "strict": true, - "esModuleInterop": true, - "forceConsistentCasingInFileNames": true, - "skipLibCheck": true, "types": ["node"] }, - "include": ["src/**/*.ts", "tests/**/*.ts"] + "include": ["src/**/*.ts", "tests/**/*.ts", "vite.config.ts"] } diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 5ac39f1e39..49e9431940 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -5,7 +5,6 @@ import InSiteMessageNotification from '@/app/components/app/in-site-message/noti import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import Zendesk from '@/app/components/base/zendesk' -import GotoAnything from '@/app/components/goto-anything' import Header from '@/app/components/header' import HeaderWrapper from '@/app/components/header/header-wrapper' import ReadmePanel from '@/app/components/plugins/readme-panel' @@ -13,10 +12,15 @@ import { AppContextProvider } from '@/context/app-context-provider' import { EventEmitterContextProvider } from '@/context/event-emitter-provider' import { ModalContextProvider } from '@/context/modal-context-provider' import { ProviderContextProvider } from '@/context/provider-context-provider' +import dynamic from '@/next/dynamic' import PartnerStack from '../components/billing/partner-stack' import Splash from '../components/splash' import RoleRouteGuard from './role-route-guard' +const GotoAnything = dynamic(() => import('@/app/components/goto-anything'), { + ssr: false, +}) + const Layout = ({ children }: { children: ReactNode }) => { return ( <> diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx index 29a08f8a01..2814072860 100644 --- a/web/app/components/app-sidebar/basic.tsx +++ b/web/app/components/app-sidebar/basic.tsx @@ -76,8 +76,8 @@ export default function AppBasic({ icon, icon_background, name, isExternal, type )} {mode === 'expand' && (
-
-
+
+
{name}
{hoverTip @@ -95,10 +95,10 @@ export default function AppBasic({ icon, icon_background, name, isExternal, type )}
{!hideType && isExtraInLine && ( -
{type}
+
{type}
)} {!hideType && !isExtraInLine && ( -
{isExternal ? t('externalTag', { ns: 'dataset' }) : type}
+
{isExternal ? t('externalTag', { ns: 'dataset' }) : type}
)}
)} diff --git a/web/app/components/app-sidebar/dataset-info/menu-item.tsx b/web/app/components/app-sidebar/dataset-info/menu-item.tsx index 7ad8d9407f..d426512176 100644 --- a/web/app/components/app-sidebar/dataset-info/menu-item.tsx +++ b/web/app/components/app-sidebar/dataset-info/menu-item.tsx @@ -22,7 +22,7 @@ const MenuItem = ({ }} > - {name} + {name}
) } diff --git a/web/app/components/app/app-publisher/__tests__/index.spec.tsx b/web/app/components/app/app-publisher/__tests__/index.spec.tsx index 86b45a2a79..a7fad12a27 100644 --- a/web/app/components/app/app-publisher/__tests__/index.spec.tsx +++ b/web/app/components/app/app-publisher/__tests__/index.spec.tsx @@ -657,178 +657,4 @@ describe('AppPublisher', () => { expect(sectionProps.summary?.workflowTypeSwitchDisabled).toBe(true) expect(sectionProps.summary?.workflowTypeSwitchDisabledReason).toBe('common.switchToEvaluationWorkflowDisabledTip') }) - - it('should switch workflow type, refresh app detail, and close the popover for published apps', async () => { - mockFetchAppDetailDirect.mockResolvedValueOnce({ - id: 'app-1', - type: AppTypeEnum.EVALUATION, - }) - - render( - , - ) - - fireEvent.click(screen.getByText('common.publish')) - fireEvent.click(screen.getByText('publisher-switch-workflow-type')) - - await waitFor(() => { - expect(mockConvertWorkflowType).toHaveBeenCalledWith({ - params: { appId: 'app-1' }, - query: { target_type: AppTypeEnum.EVALUATION }, - }) - expect(mockFetchAppDetailDirect).toHaveBeenCalledWith({ url: '/apps', id: 'app-1' }) - expect(mockSetAppDetail).toHaveBeenCalledWith({ - id: 'app-1', - type: AppTypeEnum.EVALUATION, - }) - }) - expect(screen.queryByText('publisher-summary-publish')).not.toBeInTheDocument() - }) - - it('should hide access and actions sections for evaluation workflow apps', () => { - mockAppDetail = { - ...mockAppDetail, - type: AppTypeEnum.EVALUATION, - } - - render( - , - ) - - fireEvent.click(screen.getByText('common.publish')) - - expect(screen.getByText('publisher-summary-publish')).toBeInTheDocument() - expect(screen.queryByText('publisher-access-control')).not.toBeInTheDocument() - expect(screen.queryByText('publisher-embed')).not.toBeInTheDocument() - expect(sectionProps.summary?.workflowTypeSwitchConfig).toEqual({ - targetType: AppTypeEnum.WORKFLOW, - publishLabelKey: 'common.publishAsStandardWorkflow', - switchLabelKey: 'common.switchToStandardWorkflow', - tipKey: 'common.switchToStandardWorkflowTip', - }) - }) - - it('should confirm before switching an evaluation workflow with associated targets to a standard workflow', async () => { - mockAppDetail = { - ...mockAppDetail, - type: AppTypeEnum.EVALUATION, - } - mockEvaluationWorkflowAssociatedTargets = { - items: [ - { - target_type: 'app', - target_id: 'dependent-app-1', - target_name: 'Dependent App', - }, - { - target_type: 'knowledge_base', - target_id: 'knowledge-1', - target_name: 'Knowledge Base', - }, - ], - } - mockRefetchEvaluationWorkflowAssociatedTargets.mockResolvedValueOnce({ - data: mockEvaluationWorkflowAssociatedTargets, - isError: false, - }) - - render( - , - ) - - fireEvent.click(screen.getByText('common.publish')) - fireEvent.click(screen.getByText('publisher-switch-workflow-type')) - - await waitFor(() => { - expect(mockRefetchEvaluationWorkflowAssociatedTargets).toHaveBeenCalledTimes(1) - }) - expect(mockConvertWorkflowType).not.toHaveBeenCalled() - expect(screen.getByText('Dependent App')).toBeInTheDocument() - expect(screen.getByText('Knowledge Base')).toBeInTheDocument() - - fireEvent.click(screen.getByRole('button', { name: 'common.switchToStandardWorkflowConfirm.switch' })) - - await waitFor(() => { - expect(mockConvertWorkflowType).toHaveBeenCalledWith({ - params: { appId: 'app-1' }, - query: { target_type: AppTypeEnum.WORKFLOW }, - }) - }) - }) - - it('should switch an evaluation workflow directly when there are no associated targets', async () => { - mockAppDetail = { - ...mockAppDetail, - type: AppTypeEnum.EVALUATION, - } - - render( - , - ) - - fireEvent.click(screen.getByText('common.publish')) - fireEvent.click(screen.getByText('publisher-switch-workflow-type')) - - await waitFor(() => { - expect(mockRefetchEvaluationWorkflowAssociatedTargets).toHaveBeenCalledTimes(1) - expect(mockConvertWorkflowType).toHaveBeenCalledWith({ - params: { appId: 'app-1' }, - query: { target_type: AppTypeEnum.WORKFLOW }, - }) - }) - expect(screen.queryByText('common.switchToStandardWorkflowConfirm.title')).not.toBeInTheDocument() - }) - - it('should block switching an evaluation workflow when associated targets fail to load', async () => { - mockAppDetail = { - ...mockAppDetail, - type: AppTypeEnum.EVALUATION, - } - mockRefetchEvaluationWorkflowAssociatedTargets.mockResolvedValueOnce({ - data: undefined, - isError: true, - }) - - render( - , - ) - - fireEvent.click(screen.getByText('common.publish')) - fireEvent.click(screen.getByText('publisher-switch-workflow-type')) - - await waitFor(() => { - expect(mockToastError).toHaveBeenCalledWith('common.switchToStandardWorkflowConfirm.loadFailed') - }) - expect(mockConvertWorkflowType).not.toHaveBeenCalled() - }) - - it('should block switching to evaluation workflow when restricted nodes exist', async () => { - render( - , - ) - - fireEvent.click(screen.getByText('common.publish')) - fireEvent.click(screen.getByText('publisher-switch-workflow-type')) - - await waitFor(() => { - expect(mockToastError).toHaveBeenCalledWith('common.switchToEvaluationWorkflowDisabledTip') - }) - - expect(mockConvertWorkflowType).not.toHaveBeenCalled() - expect(sectionProps.summary?.workflowTypeSwitchDisabled).toBe(true) - expect(sectionProps.summary?.workflowTypeSwitchDisabledReason).toBe('common.switchToEvaluationWorkflowDisabledTip') - }) }) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index c5e5fffaa8..e62451baca 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -256,7 +256,7 @@ const AppPublisher = ({ throw new Error('App not found') const { installed_apps } = await fetchInstalledAppList(appDetail.id) if (installed_apps?.length > 0) - return `${basePath}/explore/installed/${installed_apps[0].id}` + return `${basePath}/explore/installed/${installed_apps[0]!.id}` throw new Error('No app found in Explore') }, { onError: (err) => { diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index 2783f66c3f..1de6e6ce0c 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -96,8 +96,8 @@ const AdvancedPromptInput: FC = ({ }, onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { - if (promptVariables[i].key === newExternalDataTool.variable) { - toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) + if (promptVariables[i]!.key === newExternalDataTool.variable) { + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i]!.key })) return false } } diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 0f6d2b94a6..6b5c3acccb 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -94,8 +94,8 @@ const Prompt: FC = ({ }, onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { - if (promptVariables[i].key === newExternalDataTool.variable) { - toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) + if (promptVariables[i]!.key === newExternalDataTool.variable) { + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i]!.key })) return false } } diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 951680c035..aca2817249 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -165,8 +165,8 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar }, onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { - if (promptVariables[i].key === newExternalDataTool.variable && i !== index) { - toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) + if (promptVariables[i]!.key === newExternalDataTool.variable && i !== index) { + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i]!.key })) return false } } @@ -220,7 +220,7 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar const handleRemoveVar = useCallback((index: number) => { const removeVar = promptVariables[index] - if (mode === AppModeEnum.COMPLETION && dataSets.length > 0 && removeVar.is_context_var) { + if (mode === AppModeEnum.COMPLETION && dataSets.length > 0 && removeVar!.is_context_var) { showDeleteContextVarModal() setRemoveIndex(index) return diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx index d4ce935a4d..e5080f26e4 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx @@ -1,12 +1,11 @@ -import type { CSSProperties } from 'react' import { noop } from 'es-toolkit/function' import { memo } from 'react' import { useTranslation } from 'react-i18next' import { Slider } from '@/app/components/base/ui/slider' -const weightedScoreSliderStyle: CSSProperties & Record<'--slider-track' | '--slider-range', string> = { - '--slider-track': 'var(--color-util-colors-teal-teal-500)', - '--slider-range': 'var(--color-util-colors-blue-light-blue-light-500)', +const weightedScoreSliderSlotClassNames = { + track: 'bg-util-colors-teal-teal-500', + indicator: 'bg-util-colors-blue-light-blue-light-500', } const formatNumber = (value: number) => { @@ -36,8 +35,8 @@ const WeightedScore = ({ return (
-
-
+
+
!readonly && onChange({ value: [v, (10 - v * 10) / 10] })} disabled={readonly} aria-label={t('weightedScore.semantic', { ns: 'dataset' })} + slotClassNames={weightedScoreSliderSlotClassNames} />
-
+
{t('weightedScore.semantic', { ns: 'dataset' })}
- {formatNumber(value.value[0])} + {formatNumber(value.value[0]!)}
-
- {formatNumber(value.value[1])} +
+ {formatNumber(value.value[1]!)}
{t('weightedScore.keyword', { ns: 'dataset' })}
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx index 4b21616d46..2e535baeac 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx @@ -84,7 +84,7 @@ const DebugItem: FC = ({ style={style} >
-
+
# {index + 1}
@@ -115,7 +115,7 @@ const DebugItem: FC = ({ {showRemove && ( <> {(showDuplicate || showDebugAsSingleModel) && } - + {t('operation.remove', { ns: 'common' })} diff --git a/web/app/components/app/workflow-log/__tests__/list.spec.tsx b/web/app/components/app/workflow-log/__tests__/list.spec.tsx index 1246096cf5..a99b3b9ce5 100644 --- a/web/app/components/app/workflow-log/__tests__/list.spec.tsx +++ b/web/app/components/app/workflow-log/__tests__/list.spec.tsx @@ -182,7 +182,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(container.querySelector('.spin-animation')).toBeInTheDocument() + expect(container.querySelector('.spin-animation'))!.toBeInTheDocument() }) it('should render loading state when appDetail is undefined', () => { @@ -192,7 +192,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(container.querySelector('.spin-animation')).toBeInTheDocument() + expect(container.querySelector('.spin-animation'))!.toBeInTheDocument() }) it('should render table when data is available', () => { @@ -202,7 +202,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render all table headers', () => { @@ -212,12 +212,12 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('appLog.table.header.startTime')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.status')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.runtime')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.tokens')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.user')).toBeInTheDocument() - expect(screen.getByText('appLog.table.header.evaluation')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.startTime'))!.toBeInTheDocument() + expect(screen.getByText('appLog.table.header.status'))!.toBeInTheDocument() + expect(screen.getByText('appLog.table.header.runtime'))!.toBeInTheDocument() + expect(screen.getByText('appLog.table.header.tokens'))!.toBeInTheDocument() + expect(screen.getByText('appLog.table.header.user'))!.toBeInTheDocument() + expect(screen.getByText('appLog.table.header.evaluation'))!.toBeInTheDocument() }) it('should render trigger column for workflow apps', () => { @@ -228,7 +228,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('appLog.table.header.triggered_from')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.triggered_from'))!.toBeInTheDocument() }) it('should not render trigger column for non-workflow apps', () => { @@ -258,7 +258,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('Success'))!.toBeInTheDocument() }) it('should render failure status correctly', () => { @@ -272,7 +272,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Failure')).toBeInTheDocument() + expect(screen.getByText('Failure'))!.toBeInTheDocument() }) it('should render stopped status correctly', () => { @@ -286,7 +286,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Stop')).toBeInTheDocument() + expect(screen.getByText('Stop'))!.toBeInTheDocument() }) it('should render running status correctly', () => { @@ -300,7 +300,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Running')).toBeInTheDocument() + expect(screen.getByText('Running'))!.toBeInTheDocument() }) it('should render partial-succeeded status correctly', () => { @@ -314,7 +314,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Partial Success')).toBeInTheDocument() + expect(screen.getByText('Partial Success'))!.toBeInTheDocument() }) }) @@ -334,7 +334,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('John Doe')).toBeInTheDocument() + expect(screen.getByText('John Doe'))!.toBeInTheDocument() }) it('should display end user session id when created by end user', () => { @@ -349,7 +349,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('session-abc-123')).toBeInTheDocument() + expect(screen.getByText('session-abc-123'))!.toBeInTheDocument() }) it('should display N/A when no user info', () => { @@ -364,7 +364,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('N/A')).toBeInTheDocument() + expect(screen.getByText('N/A'))!.toBeInTheDocument() }) }) @@ -406,9 +406,8 @@ describe('WorkflowAppLogList', () => { // Arrow should rotate (indicated by class change) // The sort icon should have rotate-180 class for ascending - const sortIcon = startTimeHeader.closest('div')?.querySelector('.i-heroicons-arrow-down') - expect(sortIcon).toBeInTheDocument() - expect(sortIcon).toHaveClass('rotate-180') + const sortIcon = startTimeHeader.closest('div')?.querySelector('svg') + expect(sortIcon)!.toBeInTheDocument() }) it('should render sort arrow icon', () => { @@ -419,8 +418,8 @@ describe('WorkflowAppLogList', () => { ) // Check for ArrowDownIcon presence - const sortArrow = container.querySelector('.i-heroicons-arrow-down') - expect(sortArrow).toBeInTheDocument() + const sortArrow = container.querySelector('svg.ml-0\\.5') + expect(sortArrow)!.toBeInTheDocument() }) }) @@ -443,11 +442,11 @@ describe('WorkflowAppLogList', () => { ) const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]) // Click first data row + await user.click(dataRows[1]!) // Click first data row const dialog = await screen.findByRole('dialog') - expect(dialog).toBeInTheDocument() - expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + expect(dialog)!.toBeInTheDocument() + expect(screen.getByText('appLog.runDetail.workflowTitle'))!.toBeInTheDocument() }) it('should close drawer and call onRefresh when closing', async () => { @@ -462,7 +461,7 @@ describe('WorkflowAppLogList', () => { // Open drawer const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]) + await user.click(dataRows[1]!) await screen.findByRole('dialog') // Close drawer using Escape key @@ -485,14 +484,46 @@ describe('WorkflowAppLogList', () => { const dataRows = screen.getAllByRole('row') const dataRow = dataRows[1] + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight + // Before click - no highlight // Before click - no highlight expect(dataRow).not.toHaveClass('bg-background-default-hover') // After click - has highlight (via currentLog state) - await user.click(dataRow) + await user.click(dataRow!) // The row should have the selected class - expect(dataRow).toHaveClass('bg-background-default-hover') + // The row should have the selected class + expect(dataRow)!.toHaveClass('bg-background-default-hover') }) it('should open evaluation popover without opening drawer when clicking evaluation trigger', async () => { @@ -546,7 +577,7 @@ describe('WorkflowAppLogList', () => { // Open drawer const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]) + await user.click(dataRows[1]!) await screen.findByRole('dialog') // Replay button should be present for app-run triggers @@ -574,12 +605,12 @@ describe('WorkflowAppLogList', () => { // Open drawer const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]) + await user.click(dataRows[1]!) await screen.findByRole('dialog') // Replay button should be present for debugging triggers const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) - expect(replayButton).toBeInTheDocument() + expect(replayButton)!.toBeInTheDocument() }) it('should not show replay for webhook triggers', async () => { @@ -600,9 +631,40 @@ describe('WorkflowAppLogList', () => { // Open drawer const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]) + await user.click(dataRows[1]!) await screen.findByRole('dialog') + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers + // Replay button should not be present for webhook triggers // Replay button should not be present for webhook triggers expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() }) @@ -625,7 +687,7 @@ describe('WorkflowAppLogList', () => { // Unread indicator is a small blue dot const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') - expect(unreadDot).toBeInTheDocument() + expect(unreadDot)!.toBeInTheDocument() }) it('should not show unread indicator for read logs', () => { @@ -660,7 +722,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('1.235s')).toBeInTheDocument() + expect(screen.getByText('1.235s'))!.toBeInTheDocument() }) it('should display 0 elapsed time with special styling', () => { @@ -675,8 +737,8 @@ describe('WorkflowAppLogList', () => { ) const zeroTime = screen.getByText('0.000s') - expect(zeroTime).toBeInTheDocument() - expect(zeroTime).toHaveClass('text-text-quaternary') + expect(zeroTime)!.toBeInTheDocument() + expect(zeroTime)!.toHaveClass('text-text-quaternary') }) }) @@ -695,7 +757,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('12345')).toBeInTheDocument() + expect(screen.getByText('12345'))!.toBeInTheDocument() }) }) @@ -711,7 +773,7 @@ describe('WorkflowAppLogList', () => { ) const table = screen.getByRole('table') - expect(table).toBeInTheDocument() + expect(table)!.toBeInTheDocument() // Should only have header row const rows = screen.getAllByRole('row') @@ -752,8 +814,8 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('0.000s')).toBeInTheDocument() - expect(screen.getByText('0')).toBeInTheDocument() + expect(screen.getByText('0.000s'))!.toBeInTheDocument() + expect(screen.getByText('0'))!.toBeInTheDocument() }) it('should handle null workflow_run.triggered_from for non-workflow apps', () => { @@ -770,6 +832,37 @@ describe('WorkflowAppLogList', () => { , ) + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column + // Should render without trigger column // Should render without trigger column expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() }) diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index 7a421dfa0f..02499385e7 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -392,6 +392,26 @@ describe('List', () => { expect(mockUseInfiniteAppList).toHaveBeenCalledWith(expect.objectContaining({ creator_id: 'user-1,user-2', }), expect.any(Object)) + const clearButton = document.querySelector('.group') + expect(clearButton)!.toBeInTheDocument() + if (clearButton) + fireEvent.click(clearButton) + + expect(mockSetQuery).toHaveBeenCalled() + }) + }) + + describe('Tag Filter', () => { + it('should render tag filter component', () => { + renderList() + expect(screen.getByText('common.tag.placeholder'))!.toBeInTheDocument() + }) + }) + + describe('Created By Me Filter', () => { + it('should render checkbox with correct label', () => { + renderList() + expect(screen.getByText('app.showMyCreatedAppsOnly'))!.toBeInTheDocument() }) it('should handle checkbox change', () => { @@ -446,39 +466,39 @@ describe('List', () => { describe('Edge Cases', () => { it('should handle multiple renders without issues', () => { const { unmount } = renderWithNuqs() - expect(screen.getByText('app.types.all')).toBeInTheDocument() + expect(screen.getByText('app.types.all'))!.toBeInTheDocument() unmount() renderList() - expect(screen.getByText('app.types.all')).toBeInTheDocument() + expect(screen.getByText('app.types.all'))!.toBeInTheDocument() }) it('should render app cards correctly', () => { renderList() - expect(screen.getByText('Test App 1')).toBeInTheDocument() - expect(screen.getByText('Test App 2')).toBeInTheDocument() + expect(screen.getByText('Test App 1'))!.toBeInTheDocument() + expect(screen.getByText('Test App 2'))!.toBeInTheDocument() }) it('should render with all filter options visible', () => { renderList() - expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByText('common.tag.placeholder')).toBeInTheDocument() - expect(screen.getByText('app.showMyCreatedAppsOnly')).toBeInTheDocument() + expect(screen.getByRole('textbox'))!.toBeInTheDocument() + expect(screen.getByText('common.tag.placeholder'))!.toBeInTheDocument() + expect(screen.getByText('app.showMyCreatedAppsOnly'))!.toBeInTheDocument() }) }) describe('Dragging State', () => { it('should show drop hint when DSL feature is enabled for editors', () => { renderList() - expect(screen.getByText('app.newApp.dropDSLToCreateApp')).toBeInTheDocument() + expect(screen.getByText('app.newApp.dropDSLToCreateApp'))!.toBeInTheDocument() }) it('should render dragging state overlay when dragging', () => { mockDragging = true const { container } = renderList() - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) }) @@ -486,12 +506,12 @@ describe('List', () => { it('should render all app type tabs', () => { renderList() - expect(screen.getByText('app.types.all')).toBeInTheDocument() - expect(screen.getByText('app.types.workflow')).toBeInTheDocument() - expect(screen.getByText('app.types.advanced')).toBeInTheDocument() - expect(screen.getByText('app.types.chatbot')).toBeInTheDocument() - expect(screen.getByText('app.types.agent')).toBeInTheDocument() - expect(screen.getByText('app.types.completion')).toBeInTheDocument() + expect(screen.getByText('app.types.all'))!.toBeInTheDocument() + expect(screen.getByText('app.types.workflow'))!.toBeInTheDocument() + expect(screen.getByText('app.types.advanced'))!.toBeInTheDocument() + expect(screen.getByText('app.types.chatbot'))!.toBeInTheDocument() + expect(screen.getByText('app.types.agent'))!.toBeInTheDocument() + expect(screen.getByText('app.types.completion'))!.toBeInTheDocument() }) it('should update URL for each app type tab click', async () => { @@ -509,7 +529,7 @@ describe('List', () => { onUrlUpdate.mockClear() fireEvent.click(screen.getByText(text)) await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] expect(lastCall.searchParams.get('category')).toBe(mode) } }) @@ -519,22 +539,22 @@ describe('List', () => { it('should display all app cards from data', () => { renderList() - expect(screen.getByTestId('app-card-app-1')).toBeInTheDocument() - expect(screen.getByTestId('app-card-app-2')).toBeInTheDocument() + expect(screen.getByTestId('app-card-app-1'))!.toBeInTheDocument() + expect(screen.getByTestId('app-card-app-2'))!.toBeInTheDocument() }) it('should display app names correctly', () => { renderList() - expect(screen.getByText('Test App 1')).toBeInTheDocument() - expect(screen.getByText('Test App 2')).toBeInTheDocument() + expect(screen.getByText('Test App 1'))!.toBeInTheDocument() + expect(screen.getByText('Test App 2'))!.toBeInTheDocument() }) }) describe('Footer Visibility', () => { it('should render footer when branding is disabled', () => { renderList() - expect(screen.getByTestId('footer')).toBeInTheDocument() + expect(screen.getByTestId('footer'))!.toBeInTheDocument() }) }) @@ -548,7 +568,20 @@ describe('List', () => { mockOnDSLFileDropped(mockFile) }) - expect(screen.getByTestId('create-dsl-modal')).toBeInTheDocument() + expect(screen.getByTestId('create-dsl-modal'))!.toBeInTheDocument() + }) + + it('should close DSL modal when onClose is called', () => { + renderList() + + const mockFile = new File(['test content'], 'test.yml', { type: 'application/yaml' }) + act(() => { + if (mockOnDSLFileDropped) + mockOnDSLFileDropped(mockFile) + }) + + expect(screen.getByTestId('create-dsl-modal'))!.toBeInTheDocument() + fireEvent.click(screen.getByTestId('close-dsl-modal')) expect(screen.queryByTestId('create-dsl-modal')).not.toBeInTheDocument() }) @@ -584,6 +617,30 @@ describe('List', () => { }) expect(mockFetchSnippetNextPage).toHaveBeenCalled() + expect(screen.getByTestId('create-dsl-modal'))!.toBeInTheDocument() + + fireEvent.click(screen.getByTestId('success-dsl-modal')) + + expect(screen.queryByTestId('create-dsl-modal')).not.toBeInTheDocument() + expect(mockRefetch).toHaveBeenCalled() + }) + }) + + describe('Infinite Scroll', () => { + it('should call fetchNextPage when intersection observer triggers', () => { + mockServiceState.hasNextPage = true + renderList() + + if (intersectionCallback) { + act(() => { + intersectionCallback!( + [{ isIntersecting: true } as IntersectionObserverEntry], + {} as IntersectionObserver, + ) + }) + } + + expect(mockFetchNextPage).toHaveBeenCalled() }) it('should not render app-only controls in snippets mode', () => { @@ -623,4 +680,11 @@ describe('List', () => { expect(screen.getByTestId('empty-state')).toHaveTextContent('workflow.tabs.noSnippetsFound') }) }) + describe('Error State', () => { + it('should handle error state in useEffect', () => { + mockServiceState.error = new Error('Test error') + const { container } = renderList() + expect(container)!.toBeInTheDocument() + }) + }) }) diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 7b7dcd1c8b..7fd8a6bd35 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -2,11 +2,10 @@ import type { FC } from 'react' import type { StudioPageType } from '.' -import type { App } from '@/types/app' import type { WorkflowOnlineUser } from '@/models/app' import { cn } from '@langgenius/dify-ui/cn' import { useDebounceFn } from 'ahooks' -import { parseAsStringLiteral, useQueryState } from 'nuqs' +import { useQueryState } from 'nuqs' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Input from '@/app/components/base/input' @@ -23,7 +22,6 @@ import { useInfiniteAppList } from '@/service/use-apps' import { useInfiniteSnippetList } from '@/service/use-snippets' import SnippetCard from '../snippets/components/snippet-card' import SnippetCreateCard from '../snippets/components/snippet-create-card' -import { AppModeEnum, AppModes } from '@/types/app' import AppCard from './app-card' import { AppCardSkeleton } from './app-card-skeleton' import AppTypeFilter from './app-type-filter' @@ -173,7 +171,7 @@ const List: FC = ({ const dynamicMargin = Math.max(100, Math.min(containerHeight * 0.2, 200)) observer = new IntersectionObserver((entries) => { - if (entries[0].isIntersecting && !isPageLoading && !isNextPageFetching && !currentError && hasMore) { + if (entries[0]!.isIntersecting && !isPageLoading && !isNextPageFetching && !currentError && hasMore) { if (isAppsPage) fetchNextPage() else @@ -218,10 +216,6 @@ const List: FC = ({ handleTagsUpdate(value) }, [handleTagsUpdate]) - const appItems = useMemo(() => { - return (data?.pages ?? []).flatMap(({ data: apps }) => apps) - }, [data?.pages]) - const snippetItems = useMemo(() => { return (snippetData?.pages ?? []).flatMap(({ data }) => data) }, [snippetData?.pages]) @@ -288,7 +282,8 @@ const List: FC = ({ <>
{dragging && ( -
+
+
)}
diff --git a/web/app/components/base/audio-gallery/AudioPlayer.tsx b/web/app/components/base/audio-gallery/AudioPlayer.tsx index 9174b13356..c3b2056698 100644 --- a/web/app/components/base/audio-gallery/AudioPlayer.tsx +++ b/web/app/components/base/audio-gallery/AudioPlayer.tsx @@ -95,7 +95,7 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { for (let i = 0; i < samples; i++) { let sum = 0 for (let j = 0; j < blockSize; j++) - sum += Math.abs(channelData[i * blockSize + j]) + sum += Math.abs(channelData[i * blockSize + j]!) // Apply nonlinear scaling to enhance small amplitudes waveformData.push((sum / blockSize) * 5) } @@ -145,7 +145,7 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { e.preventDefault() const getClientX = (event: React.MouseEvent | React.TouchEvent): number => { if ('touches' in event) - return event.touches[0].clientX + return event.touches[0]!.clientX return event.clientX } const updateProgress = (clientX: number) => { diff --git a/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx b/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx index bd5f01bcda..83a8666e79 100644 --- a/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx @@ -151,8 +151,8 @@ describe('ChatWrapper', () => { render() - expect(await screen.findByText('Welcome')).toBeInTheDocument() - expect(await screen.findByText('Q1')).toBeInTheDocument() + expect(await screen.findByText('Welcome'))!.toBeInTheDocument() + expect(await screen.findByText('Q1'))!.toBeInTheDocument() fireEvent.click(screen.getByText('Q1')) expect(handleSend).toHaveBeenCalled() @@ -170,7 +170,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('Default opening statement')).toBeInTheDocument() + expect(screen.getByText('Default opening statement'))!.toBeInTheDocument() }) it('should render welcome screen without suggested questions', async () => { @@ -186,7 +186,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(await screen.findByText('Welcome message')).toBeInTheDocument() + expect(await screen.findByText('Welcome message'))!.toBeInTheDocument() }) it('should show responding state', async () => { @@ -197,7 +197,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(await screen.findByText('Bot thinking...')).toBeInTheDocument() + expect(await screen.findByText('Bot thinking...'))!.toBeInTheDocument() }) it('should handle manual message input and stop responding', async () => { @@ -320,9 +320,9 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const disabledContainer = chatInput.closest('.pointer-events-none') - expect(disabledContainer).toBeInTheDocument() - expect(disabledContainer).toHaveClass('opacity-50') + const disabledContainer = chatInput!.closest('.pointer-events-none') + expect(disabledContainer)!.toBeInTheDocument() + expect(disabledContainer)!.toHaveClass('opacity-50') }) it('should not disable input when required field has value', () => { @@ -337,7 +337,7 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput.closest('.pointer-events-none') + const container = chatInput!.closest('.pointer-events-none') expect(container).not.toBeInTheDocument() }) @@ -361,8 +361,8 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput.closest('.pointer-events-none') - expect(container).toBeInTheDocument() + const container = chatInput!.closest('.pointer-events-none') + expect(container)!.toBeInTheDocument() }) it('should not disable input when file is fully uploaded', () => { @@ -411,8 +411,8 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput.closest('.pointer-events-none') - expect(container).toBeInTheDocument() + const container = chatInput!.closest('.pointer-events-none') + expect(container)!.toBeInTheDocument() }) it('should not disable when all files are uploaded', () => { @@ -457,7 +457,7 @@ describe('ChatWrapper', () => { render() const textarea = screen.getByRole('textbox') const container = textarea.closest('.pointer-events-none') - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) it('should not disable input when allInputsHidden is true', () => { @@ -523,7 +523,7 @@ describe('ChatWrapper', () => { render() expect(handleSwitchSibling).toHaveBeenCalledWith('resume-node', expect.any(Object)) - const resumeOptions = handleSwitchSibling.mock.calls[0][1] + const resumeOptions = handleSwitchSibling.mock.calls[0]![1] resumeOptions.onGetSuggestedQuestions('response-from-resume') expect(fetchSuggestedQuestions).toHaveBeenCalledWith('response-from-resume', 'webApp', 'test-app-id') }) @@ -619,7 +619,7 @@ describe('ChatWrapper', () => { render() - const onStopCallback = vi.mocked(useChat).mock.calls[0][3] as (taskId: string) => void + const onStopCallback = vi.mocked(useChat).mock.calls[0]![3] as (taskId: string) => void onStopCallback('taskId-123') expect(stopChatMessageResponding).toHaveBeenCalledWith('', 'taskId-123', 'webApp', 'test-app-id') }) @@ -645,7 +645,7 @@ describe('ChatWrapper', () => { expect(handleSend).toHaveBeenCalled() // Get the options passed to handleSend - const options = handleSend.mock.calls[0][2] + const options = handleSend.mock.calls[0]![2] expect(options.isPublicAPI).toBe(true) // Call onGetSuggestedQuestions @@ -679,7 +679,7 @@ describe('ChatWrapper', () => { fireEvent.click(nextButton) expect(handleSwitchSibling).toHaveBeenCalled() - const options = handleSwitchSibling.mock.calls[0][1] + const options = handleSwitchSibling.mock.calls[0]![1] options.onGetSuggestedQuestions('response-id') expect(fetchSuggestedQuestions).toHaveBeenCalledWith('response-id', 'webApp', 'test-app-id') } @@ -708,8 +708,8 @@ describe('ChatWrapper', () => { expect(handleSend).toHaveBeenCalled() const args = handleSend.mock.calls[0] // args[1] is data - expect(args[1].query).toBe('Q1') - expect(args[1].parent_message_id).toBeNull() + expect(args![1].query).toBe('Q1') + expect(args![1].parent_message_id).toBeNull() } }) @@ -737,7 +737,7 @@ describe('ChatWrapper', () => { fireEvent.click(regenerateBtn) expect(handleSend).toHaveBeenCalled() const args = handleSend.mock.calls[0] - expect(args[1].parent_message_id).toBe('a0') + expect(args![1].parent_message_id).toBe('a0') } }) @@ -774,10 +774,10 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(await screen.findByText('Node 1')).toBeInTheDocument() + expect(await screen.findByText('Node 1'))!.toBeInTheDocument() const input = screen.getAllByRole('textbox').find(el => el.closest('.chat-answer-container')) || screen.getAllByRole('textbox')[0] - fireEvent.change(input, { target: { value: 'test' } }) + fireEvent.change(input!, { target: { value: 'test' } }) const runButton = screen.getByText('Run') fireEvent.click(runButton) @@ -817,10 +817,10 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(await screen.findByText('Node Web 1')).toBeInTheDocument() + expect(await screen.findByText('Node Web 1'))!.toBeInTheDocument() const input = screen.getAllByRole('textbox').find(el => el.closest('.chat-answer-container')) || screen.getAllByRole('textbox')[0] - fireEvent.change(input, { target: { value: 'web-test' } }) + fireEvent.change(input!, { target: { value: 'web-test' } }) fireEvent.click(screen.getByText('Run')) await waitFor(() => { @@ -841,7 +841,7 @@ describe('ChatWrapper', () => { render() expect(document.querySelector('.chat-answer-container')).not.toBeInTheDocument() - expect(screen.getByText('Welcome')).toBeInTheDocument() + expect(screen.getByText('Welcome'))!.toBeInTheDocument() }) it('should show all messages including opening statement when there are multiple messages', () => { @@ -861,7 +861,7 @@ describe('ChatWrapper', () => { render() const welcomeElements = screen.getAllByText('Welcome') expect(welcomeElements.length).toBeGreaterThan(0) - expect(screen.getByText('User message')).toBeInTheDocument() + expect(screen.getByText('User message'))!.toBeInTheDocument() }) it('should show chatNode and inputs form on desktop for new conversation', () => { @@ -873,7 +873,7 @@ describe('ChatWrapper', () => { }) render() - expect(screen.getByText('Test')).toBeInTheDocument() + expect(screen.getByText('Test'))!.toBeInTheDocument() }) it('should show chatNode on mobile for new conversation only', () => { @@ -885,7 +885,7 @@ describe('ChatWrapper', () => { }) const { rerender } = render() - expect(screen.getByText('Test')).toBeInTheDocument() + expect(screen.getByText('Test'))!.toBeInTheDocument() vi.mocked(useChatWithHistoryContext).mockReturnValue({ ...defaultContextValue, @@ -974,8 +974,8 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('Answer')).toBeInTheDocument() - expect(screen.getByAltText('answer icon')).toBeInTheDocument() + expect(screen.getByText('Answer'))!.toBeInTheDocument() + expect(screen.getByAltText('answer icon'))!.toBeInTheDocument() }) it('should render question icon fallback when user avatar is available', () => { @@ -993,7 +993,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('J')).toBeInTheDocument() + expect(screen.getByText('J'))!.toBeInTheDocument() }) it('should use fallback values for nullable appData, appMeta and avatar name', () => { @@ -1012,8 +1012,8 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('Question with fallback avatar name')).toBeInTheDocument() - expect(screen.getByText('U')).toBeInTheDocument() + expect(screen.getByText('Question with fallback avatar name'))!.toBeInTheDocument() + expect(screen.getByText('U'))!.toBeInTheDocument() }) it('should set handleStop on currentChatInstanceRef', () => { @@ -1101,8 +1101,8 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput.closest('.pointer-events-none') - expect(container).toBeInTheDocument() + const container = chatInput!.closest('.pointer-events-none') + expect(container)!.toBeInTheDocument() }) it('should call formatBooleanInputs when sending message', async () => { @@ -1223,7 +1223,8 @@ describe('ChatWrapper', () => { render() // This tests line 91 - using currentConversationItem.introduction - expect(screen.getByText('Custom introduction from conversation item')).toBeInTheDocument() + // This tests line 91 - using currentConversationItem.introduction + expect(screen.getByText('Custom introduction from conversation item'))!.toBeInTheDocument() }) it('should handle early return when hasEmptyInput is already set', () => { @@ -1242,8 +1243,8 @@ describe('ChatWrapper', () => { // This tests line 106 - early return when hasEmptyInput is set const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput.closest('.pointer-events-none') - expect(container).toBeInTheDocument() + const container = chatInput!.closest('.pointer-events-none') + expect(container)!.toBeInTheDocument() }) it('should handle early return when fileIsUploading is already set', () => { @@ -1270,8 +1271,8 @@ describe('ChatWrapper', () => { // This tests line 109 - early return when fileIsUploading is set const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput.closest('.pointer-events-none') - expect(container).toBeInTheDocument() + const container = chatInput!.closest('.pointer-events-none') + expect(container)!.toBeInTheDocument() }) it('should handle doSend with no parent message id', async () => { @@ -1561,7 +1562,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('Default opening statement')).toBeInTheDocument() + expect(screen.getByText('Default opening statement'))!.toBeInTheDocument() }) it('should handle doSend when regenerating with null parentAnswer', async () => { @@ -1609,7 +1610,9 @@ describe('ChatWrapper', () => { // Just verify the component renders - the actual editedQuestion flow // is tested through the doRegenerate callback that's passed to Chat - expect(screen.getByText('Answer')).toBeInTheDocument() + // Just verify the component renders - the actual editedQuestion flow + // is tested through the doRegenerate callback that's passed to Chat + expect(screen.getByText('Answer'))!.toBeInTheDocument() expect(handleSend).toBeDefined() }) @@ -1629,7 +1632,9 @@ describe('ChatWrapper', () => { // The doRegenerate is passed to Chat component and would be called // This ensures lines 198-200 are covered - expect(screen.getByText('A1')).toBeInTheDocument() + // The doRegenerate is passed to Chat component and would be called + // This ensures lines 198-200 are covered + expect(screen.getByText('A1'))!.toBeInTheDocument() }) it('should handle doRegenerate when question has message_files', async () => { @@ -1809,7 +1814,38 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput.closest('.pointer-events-none') + const container = chatInput!.closest('.pointer-events-none') + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required + // Should not be disabled because it's not required // Should not be disabled because it's not required expect(container).not.toBeInTheDocument() }) diff --git a/web/app/components/base/chat/chat-with-history/header/__tests__/index.spec.tsx b/web/app/components/base/chat/chat-with-history/header/__tests__/index.spec.tsx index 5feaccd191..b1c23a129b 100644 --- a/web/app/components/base/chat/chat-with-history/header/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/header/__tests__/index.spec.tsx @@ -108,7 +108,7 @@ describe('Header Component', () => { currentConversationItem: mockConv, sidebarCollapseState: true, }) - expect(screen.getByText('My Chat')).toBeInTheDocument() + expect(screen.getByText('My Chat'))!.toBeInTheDocument() }) it('should render ViewFormDropdown trigger when inputsForms are present', () => { @@ -133,7 +133,7 @@ describe('Header Component', () => { const buttons = screen.getAllByRole('button') // Sidebar, NewChat, ResetChat (3) const resetChatBtn = buttons[buttons.length - 1] - await userEvent.click(resetChatBtn) + await userEvent.click(resetChatBtn!) expect(handleNewConversation).toHaveBeenCalled() }) @@ -144,7 +144,7 @@ describe('Header Component', () => { const buttons = screen.getAllByRole('button') const sidebarBtn = buttons[0] - await userEvent.click(sidebarBtn) + await userEvent.click(sidebarBtn!) expect(handleSidebarCollapse).toHaveBeenCalledWith(false) }) @@ -163,7 +163,7 @@ describe('Header Component', () => { await userEvent.click(trigger) const pinBtn = await screen.findByText('explore.sidebar.action.pin') - expect(pinBtn).toBeInTheDocument() + expect(pinBtn)!.toBeInTheDocument() await userEvent.click(pinBtn) @@ -225,7 +225,7 @@ describe('Header Component', () => { const renameMenuBtn = await screen.findByText('explore.sidebar.action.rename') await userEvent.click(renameMenuBtn) - expect(await screen.findByText('common.chat.renameConversation')).toBeInTheDocument() + expect(await screen.findByText('common.chat.renameConversation'))!.toBeInTheDocument() const input = screen.getByDisplayValue('My Chat') await userEvent.clear(input) @@ -236,7 +236,7 @@ describe('Header Component', () => { expect(handleRenameConversation).toHaveBeenCalledWith('conv-1', 'New Name', expect.any(Object)) - const successCallback = handleRenameConversation.mock.calls[0][2].onSuccess + const successCallback = handleRenameConversation.mock.calls[0]![2].onSuccess await act(async () => { successCallback() }) @@ -262,14 +262,14 @@ describe('Header Component', () => { await userEvent.click(deleteMenuBtn) expect(handleDeleteConversation).not.toHaveBeenCalled() - expect(await screen.findByText('share.chat.deleteConversation.title')).toBeInTheDocument() + expect(await screen.findByText('share.chat.deleteConversation.title'))!.toBeInTheDocument() const confirmBtn = await screen.findByText('common.operation.confirm') await userEvent.click(confirmBtn) expect(handleDeleteConversation).toHaveBeenCalledWith('conv-1', expect.any(Object)) - const successCallback = handleDeleteConversation.mock.calls[0][1].onSuccess + const successCallback = handleDeleteConversation.mock.calls[0]![1].onSuccess await act(async () => { successCallback() }) @@ -311,7 +311,7 @@ describe('Header Component', () => { await userEvent.click(screen.getByText('My Chat')) await userEvent.click(await screen.findByText('explore.sidebar.action.delete')) - expect(await screen.findByText('share.chat.deleteConversation.title')).toBeInTheDocument() + expect(await screen.findByText('share.chat.deleteConversation.title'))!.toBeInTheDocument() }) }) @@ -332,7 +332,7 @@ describe('Header Component', () => { it('should render system title if conversation id is missing', () => { setup({ currentConversationId: '', sidebarCollapseState: true }) const titleEl = screen.getByText('Test App') - expect(titleEl).toHaveClass('system-md-semibold') + expect(titleEl)!.toHaveClass('system-md-semibold') }) it('should render app icon from URL when icon_url is provided', () => { @@ -347,7 +347,7 @@ describe('Header Component', () => { }, }) const img = screen.getByAltText('app icon') - expect(img).toHaveAttribute('src', 'https://example.com/icon.png') + expect(img)!.toHaveAttribute('src', 'https://example.com/icon.png') }) it('should handle undefined appData gracefully (optional chaining)', () => { @@ -364,7 +364,8 @@ describe('Header Component', () => { sidebarCollapseState: true, }) // The separator is just a div with text content '/' - expect(screen.getByText('/')).toBeInTheDocument() + // The separator is just a div with text content '/' + expect(screen.getByText('/'))!.toBeInTheDocument() }) it('should handle New Chat button state when currentConversationId is present but isResponding is true', () => { @@ -377,7 +378,7 @@ describe('Header Component', () => { const buttons = screen.getAllByRole('button') // Sidebar, NewChat, ResetChat (3) const newChatBtn = buttons[1] - expect(newChatBtn).toBeDisabled() + expect(newChatBtn)!.toBeDisabled() }) it('should handle New Chat button state when currentConversationId is missing and isResponding is false', () => { @@ -390,7 +391,7 @@ describe('Header Component', () => { const buttons = screen.getAllByRole('button') // Sidebar, NewChat (2) const newChatBtn = buttons[1] - expect(newChatBtn).toBeDisabled() + expect(newChatBtn)!.toBeDisabled() }) it('should not render operation menu if conversation id is missing', () => { diff --git a/web/app/components/base/chat/chat-with-history/header/operation.tsx b/web/app/components/base/chat/chat-with-history/header/operation.tsx index a6dd6a0a9e..d439a43c1f 100644 --- a/web/app/components/base/chat/chat-with-history/header/operation.tsx +++ b/web/app/components/base/chat/chat-with-history/header/operation.tsx @@ -71,7 +71,7 @@ const Operation: FC = ({ )} {isShowDelete && ( handleDeferredAction(onDelete)} > diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index e6f5657ff5..df261d750c 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -452,7 +452,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) setOriginConversationList(produce((draft) => { const index = originConversationList.findIndex(item => item.id === conversationId) - const item = draft[index] + const item = draft[index]! draft[index] = { ...item, name: newName, diff --git a/web/app/components/base/chat/chat-with-history/inputs-form/__tests__/content.spec.tsx b/web/app/components/base/chat/chat-with-history/inputs-form/__tests__/content.spec.tsx index c1a0f3e294..6081024490 100644 --- a/web/app/components/base/chat/chat-with-history/inputs-form/__tests__/content.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/inputs-form/__tests__/content.spec.tsx @@ -248,6 +248,19 @@ describe('InputsFormContent', () => { expect(mockSetCurrentConversationInputs).toHaveBeenCalledWith(expect.objectContaining({ sel: 'A' })) }) + it('renders select dropdown above the settings dialog layer', async () => { + const user = userEvent.setup() + const context = createMockContext({ + inputsForms: [{ variable: 'sel', type: InputVarType.select, label: 'Sel', options: ['A', 'B'], default: 'B' }], + currentConversationInputs: {}, + }) + + renderWithContext(, context) + await user.click(screen.getByText('B')) + + expect(screen.getByText('A').closest('.z-\\[60\\]')).not.toBeNull() + }) + it('handles select input with existing value (value not in options -> shows placeholder)', () => { const context = createMockContext({ inputsForms: [{ variable: 'sel', type: InputVarType.select, label: 'Sel', options: ['A'], default: undefined }], diff --git a/web/app/components/base/chat/chat-with-history/inputs-form/content.tsx b/web/app/components/base/chat/chat-with-history/inputs-form/content.tsx index 127cf2c252..4baa46744d 100644 --- a/web/app/components/base/chat/chat-with-history/inputs-form/content.tsx +++ b/web/app/components/base/chat/chat-with-history/inputs-form/content.tsx @@ -86,7 +86,7 @@ const InputsFormContent = ({ showTip }: Props) => { )} {form.type === InputVarType.select && ( ({ value: option, name: option }))} onSelect={item => handleFormChange(form.variable, item.value as string)} diff --git a/web/app/components/base/chat/chat-with-history/sidebar/operation.tsx b/web/app/components/base/chat/chat-with-history/sidebar/operation.tsx index 611d2bb1b9..adda03fb55 100644 --- a/web/app/components/base/chat/chat-with-history/sidebar/operation.tsx +++ b/web/app/components/base/chat/chat-with-history/sidebar/operation.tsx @@ -105,7 +105,7 @@ const Operation: FC = ({ )} {isShowDelete && ( { e.stopPropagation() diff --git a/web/app/components/base/chat/chat/citation/popup.tsx b/web/app/components/base/chat/chat/citation/popup.tsx index 2b4070b69a..51a73bc4b6 100644 --- a/web/app/components/base/chat/chat/citation/popup.tsx +++ b/web/app/components/base/chat/chat/citation/popup.tsx @@ -64,10 +64,10 @@ const Popup: FC = ({
-
+
-
+
{(data.dataSourceType === 'upload_file' || data.dataSourceType === 'file') && !!data.sources?.[0]?.dataset_id ? ( ) - expect(screen.getByRole('button', { name: 'Run' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Run' }))!.toBeInTheDocument() }) it('should render correctly without optional className props', () => { const { wrapper, canvasLayer, gradientLayer, contentLayer } = renderGridMask({}, Plain child) - expect(wrapper).toHaveClass('bg-saas-background') - expect(canvasLayer).toHaveClass('absolute') - expect(gradientLayer).toHaveClass('absolute') - expect(contentLayer).toHaveTextContent('Plain child') + expect(wrapper)!.toHaveClass('bg-saas-background') + expect(canvasLayer)!.toHaveClass('absolute') + expect(gradientLayer)!.toHaveClass('absolute') + expect(contentLayer)!.toHaveTextContent('Plain child') }) it('should render wrapper, canvas, gradient and content layers in order', () => { const { wrapper, canvasLayer, gradientLayer, contentLayer } = renderGridMask({}, Content) - expect(wrapper).toBeInTheDocument() + expect(wrapper)!.toBeInTheDocument() expect(wrapper.children).toHaveLength(3) - expect(canvasLayer).toHaveClass('z-0') - expect(gradientLayer).toHaveClass('z-1') - expect(contentLayer).toHaveClass('z-2') - expect(contentLayer).toHaveTextContent('Content') + expect(canvasLayer)!.toHaveClass('z-0') + expect(gradientLayer)!.toHaveClass('z-1') + expect(contentLayer)!.toHaveClass('z-2') + expect(contentLayer)!.toHaveTextContent('Content') }) }) describe('Props', () => { it('should apply wrapperClassName to wrapper element', () => { const { wrapper } = renderGridMask({ wrapperClassName: 'custom-wrapper' }, Child) - expect(wrapper).toHaveClass('custom-wrapper') - expect(wrapper).toHaveClass('relative') + expect(wrapper)!.toHaveClass('custom-wrapper') + expect(wrapper)!.toHaveClass('relative') }) it('should apply canvasClassName and grid background class to canvas layer', () => { const { canvasLayer } = renderGridMask({ canvasClassName: 'custom-canvas' }, Child) - expect(canvasLayer).toHaveClass('custom-canvas') - expect(canvasLayer).toHaveClass(Style.gridBg) + expect(canvasLayer)!.toHaveClass('custom-canvas') + expect(canvasLayer)!.toHaveClass(Style.gridBg!) }) it('should apply gradientClassName to gradient layer', () => { const { gradientLayer } = renderGridMask({ gradientClassName: 'custom-gradient' }, Child) - expect(gradientLayer).toHaveClass('custom-gradient') - expect(gradientLayer).toHaveClass('bg-grid-mask-background') + expect(gradientLayer)!.toHaveClass('custom-gradient') + expect(gradientLayer)!.toHaveClass('bg-grid-mask-background') }) }) }) diff --git a/web/app/components/base/image-uploader/hooks.ts b/web/app/components/base/image-uploader/hooks.ts index 8a1a082b0f..ec1b1248f1 100644 --- a/web/app/components/base/image-uploader/hooks.ts +++ b/web/app/components/base/image-uploader/hooks.ts @@ -31,7 +31,7 @@ export const useImageFiles = () => { const files = filesRef.current const index = files.findIndex(file => file._id === imageFileId) if (index > -1) { - const currentFile = files[index] + const currentFile = files[index]! const newFiles = [...files.slice(0, index), { ...currentFile, deleted: true }, ...files.slice(index + 1)] setFiles(newFiles) filesRef.current = newFiles @@ -41,7 +41,7 @@ export const useImageFiles = () => { const files = filesRef.current const index = files.findIndex(file => file._id === imageFileId) if (index > -1) { - const currentFile = files[index] + const currentFile = files[index]! const newFiles = [...files.slice(0, index), { ...currentFile, progress: -1 }, ...files.slice(index + 1)] filesRef.current = newFiles setFiles(newFiles) @@ -51,7 +51,7 @@ export const useImageFiles = () => { const files = filesRef.current const index = files.findIndex(file => file._id === imageFileId) if (index > -1) { - const currentImageFile = files[index] + const currentImageFile = files[index]! const newFiles = [...files.slice(0, index), { ...currentImageFile, progress: 100 }, ...files.slice(index + 1)] filesRef.current = newFiles setFiles(newFiles) @@ -61,9 +61,9 @@ export const useImageFiles = () => { const files = filesRef.current const index = files.findIndex(file => file._id === imageFileId) if (index > -1) { - const currentImageFile = files[index] + const currentImageFile = files[index]! imageUpload({ - file: currentImageFile.file!, + file: currentImageFile!.file!, onProgressCallback: (progress) => { const newFiles = [...files.slice(0, index), { ...currentImageFile, progress }, ...files.slice(index + 1)] filesRef.current = newFiles @@ -114,7 +114,7 @@ export const useLocalFileUploader = ({ limit, disabled = false, onUpload }: useL // TODO: leave some warnings? return } - if (!ALLOW_FILE_EXTENSIONS.includes(file.type.split('/')[1])) + if (!ALLOW_FILE_EXTENSIONS.includes(file.type.split('/')[1]!)) return if (limit && file.size > limit * 1024 * 1024) { toast.error(t('imageUploader.uploadFromComputerLimit', { ns: 'common', size: limit })) diff --git a/web/app/components/base/image-uploader/image-list.stories.tsx b/web/app/components/base/image-uploader/image-list.stories.tsx index cfea4a0da0..0c27211d16 100644 --- a/web/app/components/base/image-uploader/image-list.stories.tsx +++ b/web/app/components/base/image-uploader/image-list.stories.tsx @@ -132,7 +132,7 @@ const ImageUploaderPlayground = ({ readonly }: Story['args']) => { return (
- Add images + Add images
)} diff --git a/web/app/components/base/notion-icon/index.tsx b/web/app/components/base/notion-icon/index.tsx index 62fcef1dc1..f2b5146d73 100644 --- a/web/app/components/base/notion-icon/index.tsx +++ b/web/app/components/base/notion-icon/index.tsx @@ -31,7 +31,7 @@ const NotionIcon = ({ ) } return ( -
{name?.[0].toLocaleUpperCase()}
+
{name?.[0]!.toLocaleUpperCase()}
) } diff --git a/web/app/components/base/notion-page-selector/page-selector/__tests__/index.spec.tsx b/web/app/components/base/notion-page-selector/page-selector/__tests__/index.spec.tsx index d4b559452e..21a7a08d63 100644 --- a/web/app/components/base/notion-page-selector/page-selector/__tests__/index.spec.tsx +++ b/web/app/components/base/notion-page-selector/page-selector/__tests__/index.spec.tsx @@ -24,11 +24,11 @@ const mockList: DataSourceNotionPage[] = [ ] const mockPagesMap: DataSourceNotionPageMap = { - 'root-1': { ...mockList[0], workspace_id: 'workspace-1' }, - 'child-1': { ...mockList[1], workspace_id: 'workspace-1' }, - 'grandchild-1': { ...mockList[2], workspace_id: 'workspace-1' }, - 'child-2': { ...mockList[3], workspace_id: 'workspace-1' }, - 'root-2': { ...mockList[4], workspace_id: 'workspace-1' }, + 'root-1': { ...mockList[0]!, workspace_id: 'workspace-1' }, + 'child-1': { ...mockList[1]!, workspace_id: 'workspace-1' }, + 'grandchild-1': { ...mockList[2]!, workspace_id: 'workspace-1' }, + 'child-2': { ...mockList[3]!, workspace_id: 'workspace-1' }, + 'root-2': { ...mockList[4]!, workspace_id: 'workspace-1' }, } describe('PageSelector', () => { @@ -39,7 +39,7 @@ describe('PageSelector', () => { it('should render root level pages initially', () => { render() - expect(screen.getByText('Root 1')).toBeInTheDocument() + expect(screen.getByText('Root 1'))!.toBeInTheDocument() expect(screen.queryByText('Child 1')).not.toBeInTheDocument() }) @@ -50,13 +50,13 @@ describe('PageSelector', () => { const toggle = screen.getByTestId('notion-page-toggle-root-1') await user.click(toggle) - expect(screen.getByText('Child 1')).toBeInTheDocument() + expect(screen.getByText('Child 1'))!.toBeInTheDocument() }) it('should call onSelect with descendants when parent is selected', async () => { const handleSelect = vi.fn() const user = userEvent.setup() - render() + render() const checkbox = screen.getByTestId('checkbox-notion-page-checkbox-root-1') await user.click(checkbox) @@ -78,7 +78,7 @@ describe('PageSelector', () => { it('should show breadcrumbs when searching', () => { render() - expect(screen.getByText('Root 1 / Child 1 / Grandchild 1')).toBeInTheDocument() + expect(screen.getByText('Root 1 / Child 1 / Grandchild 1'))!.toBeInTheDocument() }) it('should call onPreview when preview button is clicked', async () => { @@ -95,7 +95,7 @@ describe('PageSelector', () => { it('should show no result message when search returns nothing', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument() }) it('should handle selection when searchValue is present', async () => { @@ -124,7 +124,7 @@ describe('PageSelector', () => { const toggleBtn = screen.getByTestId('notion-page-toggle-root-1') await user.click(toggleBtn) // Expand - await waitFor(() => expect(screen.queryByText('Child 1')).toBeInTheDocument()) + await waitFor(() => expect(screen.queryByText('Child 1'))!.toBeInTheDocument()) await user.click(toggleBtn) // Collapse await waitFor(() => expect(screen.queryByText('Child 1')).not.toBeInTheDocument()) @@ -149,14 +149,14 @@ describe('PageSelector', () => { it('should render preview button when canPreview is true', () => { render() - expect(screen.getByTestId('notion-page-preview-root-1')).toBeInTheDocument() + expect(screen.getByTestId('notion-page-preview-root-1'))!.toBeInTheDocument() }) it('should use previewPageId prop when provided', () => { const { rerender } = render() let row = screen.getByTestId('notion-page-row-root-1') - expect(row).toHaveClass('bg-state-base-hover') + expect(row)!.toHaveClass('bg-state-base-hover') rerender() @@ -190,8 +190,9 @@ describe('PageSelector', () => { await user.click(toggle) // Both children should be visible - expect(screen.getByText('Child 1')).toBeInTheDocument() - expect(screen.getByText('Child 2')).toBeInTheDocument() + // Both children should be visible + expect(screen.getByText('Child 1'))!.toBeInTheDocument() + expect(screen.getByText('Child 2'))!.toBeInTheDocument() }) it('should expand nested children when toggling parent', async () => { @@ -201,12 +202,12 @@ describe('PageSelector', () => { // Expand root-1 let toggle = screen.getByTestId('notion-page-toggle-root-1') await user.click(toggle) - expect(screen.getByText('Child 1')).toBeInTheDocument() + expect(screen.getByText('Child 1'))!.toBeInTheDocument() // Expand child-1 toggle = screen.getByTestId('notion-page-toggle-child-1') await user.click(toggle) - expect(screen.getByText('Grandchild 1')).toBeInTheDocument() + expect(screen.getByText('Grandchild 1'))!.toBeInTheDocument() // Collapse child-1 await user.click(toggle) @@ -227,7 +228,7 @@ describe('PageSelector', () => { it('should only select the item when searching (no descendants)', async () => { const handleSelect = vi.fn() const user = userEvent.setup() - render() + render() const checkbox = screen.getByTestId('checkbox-notion-page-checkbox-child-1') await user.click(checkbox) @@ -239,7 +240,7 @@ describe('PageSelector', () => { it('should deselect only the item when searching (no descendants)', async () => { const handleSelect = vi.fn() const user = userEvent.setup() - render() + render() const checkbox = screen.getByTestId('checkbox-notion-page-checkbox-child-1') await user.click(checkbox) @@ -250,8 +251,8 @@ describe('PageSelector', () => { it('should handle multiple root pages', async () => { render() - expect(screen.getByText('Root 1')).toBeInTheDocument() - expect(screen.getByText('Root 2')).toBeInTheDocument() + expect(screen.getByText('Root 1'))!.toBeInTheDocument() + expect(screen.getByText('Root 2'))!.toBeInTheDocument() }) it('should update preview when clicking preview button with onPreview provided', async () => { @@ -276,29 +277,61 @@ describe('PageSelector', () => { rerender() const row = screen.getByTestId('notion-page-row-root-1') - expect(row).toHaveClass('bg-state-base-hover') + expect(row)!.toHaveClass('bg-state-base-hover') }) it('should render page name with correct title attribute', () => { render() const pageName = screen.getByTestId('notion-page-name-root-1') - expect(pageName).toHaveAttribute('title', 'Root 1') + expect(pageName)!.toHaveAttribute('title', 'Root 1') }) it('should handle empty list gracefully', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument() }) it('should filter search results correctly with partial matches', () => { render() // Should show Root 1, Child 1, and Grandchild 1 - expect(screen.getByTestId('notion-page-name-root-1')).toBeInTheDocument() - expect(screen.getByTestId('notion-page-name-child-1')).toBeInTheDocument() - expect(screen.getByTestId('notion-page-name-grandchild-1')).toBeInTheDocument() + // Should show Root 1, Child 1, and Grandchild 1 + expect(screen.getByTestId('notion-page-name-root-1'))!.toBeInTheDocument() + expect(screen.getByTestId('notion-page-name-child-1'))!.toBeInTheDocument() + expect(screen.getByTestId('notion-page-name-grandchild-1'))!.toBeInTheDocument() + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 + // Should not show Root 2, Child 2 // Should not show Root 2, Child 2 expect(screen.queryByTestId('notion-page-name-root-2')).not.toBeInTheDocument() expect(screen.queryByTestId('notion-page-name-child-2')).not.toBeInTheDocument() @@ -313,6 +346,7 @@ describe('PageSelector', () => { await user.click(toggle) // Should expand even though parent is disabled - expect(screen.getByText('Child 1')).toBeInTheDocument() + // Should expand even though parent is disabled + expect(screen.getByText('Child 1'))!.toBeInTheDocument() }) }) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/prompt-option.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/prompt-option.tsx index a36403b898..1499fc1d7f 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/prompt-option.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/prompt-option.tsx @@ -23,7 +23,7 @@ export const PromptMenuItem = memo(({ className={` flex h-6 cursor-pointer items-center rounded-md px-3 hover:bg-state-base-hover ${isSelected && !disabled && 'bg-state-base-hover!'} - ${disabled ? 'cursor-not-allowed opacity-30' : 'cursor-pointer hover:bg-state-base-hover'} + ${disabled ? 'cursor-not-allowed opacity-30' : ''} `} tabIndex={-1} ref={setRefElement} diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/component.spec.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/component.spec.tsx index f219f2f805..ee82595d1c 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/component.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/component.spec.tsx @@ -100,8 +100,8 @@ describe('HITLInputComponent', () => { await user.click(screen.getByRole('button', { name: 'emit-same-name' })) expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange.mock.calls[0][0]).toHaveLength(1) - expect(onChange.mock.calls[0][0][0].output_variable_name).toBe('user_name') + expect(onChange.mock.calls[0]![0]).toHaveLength(1) + expect(onChange.mock.calls[0]![0][0].output_variable_name).toBe('user_name') }) it('should replace payload when variable name is renamed', async () => { @@ -124,7 +124,7 @@ describe('HITLInputComponent', () => { await user.click(screen.getByRole('button', { name: 'emit-rename' })) expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange.mock.calls[0][0][0].output_variable_name).toBe('renamed_name') + expect(onChange.mock.calls[0]![0][0].output_variable_name).toBe('renamed_name') }) it('should update existing payload when variable name stays the same', async () => { @@ -157,9 +157,9 @@ describe('HITLInputComponent', () => { await user.click(screen.getByRole('button', { name: 'emit-update' })) expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange.mock.calls[0][0][0].default.value).toBe('updated') - expect(onChange.mock.calls[0][0][0].output_variable_name).toBe('user_name') - expect(onChange.mock.calls[0][0][1].output_variable_name).toBe('other_name') - expect(onChange.mock.calls[0][0][1].default.value).toBe('other') + expect(onChange.mock.calls[0]![0][0].default.value).toBe('updated') + expect(onChange.mock.calls[0]![0][0].output_variable_name).toBe('user_name') + expect(onChange.mock.calls[0]![0][1].output_variable_name).toBe('other_name') + expect(onChange.mock.calls[0]![0][1].default.value).toBe('other') }) }) diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/pre-populate.spec.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/pre-populate.spec.tsx index f5efc52c23..990a7ced4a 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/pre-populate.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/pre-populate.spec.tsx @@ -82,12 +82,12 @@ describe('PrePopulate', () => { />, ) - expect(screen.getByText('Static Content')).toBeInTheDocument() + expect(screen.getByText('Static Content'))!.toBeInTheDocument() await user.keyboard('{Tab}') expect(screen.queryByText('Static Content')).not.toBeInTheDocument() - expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('textbox'))!.toBeInTheDocument() }) it('should update constant value and toggle to variable mode when type switch is clicked', async () => { @@ -154,7 +154,7 @@ describe('PrePopulate', () => { />, ) - const pickerProps = mockVarReferencePicker.mock.calls[0][0] as VarReferencePickerProps + const pickerProps = mockVarReferencePicker.mock.calls[0]![0] as VarReferencePickerProps const allowString = pickerProps.filterVar({ type: 'string' } as Var) const allowNumber = pickerProps.filterVar({ type: 'number' } as Var) diff --git a/web/app/components/base/prompt-editor/plugins/query-block/component.tsx b/web/app/components/base/prompt-editor/plugins/query-block/component.tsx index a5b5969904..cd5b60bc9b 100644 --- a/web/app/components/base/prompt-editor/plugins/query-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/query-block/component.tsx @@ -17,7 +17,7 @@ const QueryBlockComponent: FC = ({ return (
{ />, ) - expect(screen.getByRole('button', { name: 'label' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'label' }))!.toBeInTheDocument() expect(mockHasNodes).toHaveBeenCalledWith([WorkflowVariableBlockNode]) expect(mockRegisterCommand).toHaveBeenCalledWith( UPDATE_WORKFLOW_NODES_MAP, @@ -188,7 +188,7 @@ describe('WorkflowVariableBlockComponent', () => { />, ) - expect(screen.getByRole('button', { name: 'label' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'label' }))!.toBeInTheDocument() }) it('should pass computed varType when getVarType is provided', () => { @@ -489,7 +489,7 @@ describe('WorkflowVariableBlockComponent', () => { />, ) - const updateHandler = mockRegisterCommand.mock.calls[0][1] as (payload: UpdateWorkflowNodesMapPayload) => boolean + const updateHandler = mockRegisterCommand.mock.calls[0]![1] as (payload: UpdateWorkflowNodesMapPayload) => boolean let result = false act(() => { result = updateHandler({ diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx index 2d13627b20..20fc7c6e79 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx @@ -19,11 +19,11 @@ export class WorkflowVariableBlockNode extends DecoratorNode __getVarType?: GetVarType __availableVariables?: NodeOutPutVar[] - static getType(): string { + static override getType(): string { return 'workflow-variable-block' } - static clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode { + static override clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode { return new WorkflowVariableBlockNode( node.__variables, node.__workflowNodesMap, @@ -33,7 +33,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode ) } - isInline(): boolean { + override isInline(): boolean { return true } @@ -52,17 +52,17 @@ export class WorkflowVariableBlockNode extends DecoratorNode this.__availableVariables = availableVariables } - createDOM(): HTMLElement { + override createDOM(): HTMLElement { const div = document.createElement('div') div.classList.add('inline-flex', 'items-center', 'align-middle') return div } - updateDOM(): false { + override updateDOM(): false { return false } - decorate(): React.JSX.Element { + override decorate(): React.JSX.Element { return ( ) } - static importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode { + static override importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode { const node = $createWorkflowVariableBlockNode( serializedNode.variables, serializedNode.workflowNodesMap, @@ -85,7 +85,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode return node } - exportJSON(): SerializedNode { + override exportJSON(): SerializedNode { const json: SerializedNode = { type: 'workflow-variable-block', version: 1, @@ -119,7 +119,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode return self.__availableVariables } - getTextContent(): string { + override getTextContent(): string { return `{{#${this.getVariables().join('.')}#}}` } } diff --git a/web/app/components/base/prompt-log-modal/index.tsx b/web/app/components/base/prompt-log-modal/index.tsx index 6a79dfffeb..08200623ae 100644 --- a/web/app/components/base/prompt-log-modal/index.tsx +++ b/web/app/components/base/prompt-log-modal/index.tsx @@ -42,13 +42,13 @@ const PromptLogModal: FC = ({ }} ref={ref} > -
+
PROMPT LOG
{ currentLogItem.log?.length === 1 && ( <> - +
) diff --git a/web/app/components/base/select/locale-signin.tsx b/web/app/components/base/select/locale-signin.tsx index 3c5dd999f6..046a76a5d4 100644 --- a/web/app/components/base/select/locale-signin.tsx +++ b/web/app/components/base/select/locale-signin.tsx @@ -36,7 +36,7 @@ export default function LocaleSigninSelect({ leaveTo="transform opacity-0 scale-95" > -
+
{items.map((item) => { return ( diff --git a/web/app/components/base/spinner/index.tsx b/web/app/components/base/spinner/index.tsx index 65fea46a91..48ee65b99f 100644 --- a/web/app/components/base/spinner/index.tsx +++ b/web/app/components/base/spinner/index.tsx @@ -14,7 +14,7 @@ const Spinner: FC = ({ loading = false, children, className }) => { role="status" > Loading... diff --git a/web/app/components/base/tag-management/index.tsx b/web/app/components/base/tag-management/index.tsx index 79c557a8b9..8e693fb9f1 100644 --- a/web/app/components/base/tag-management/index.tsx +++ b/web/app/components/base/tag-management/index.tsx @@ -48,8 +48,8 @@ const TagManagementModal = ({ show, type }: TagManagementModalProps) => { }, [type]) return ( setShowTagManagementModal(false)}> -
{t('tag.manageTags', { ns: 'common' })}
-
setShowTagManagementModal(false)}> +
{t('tag.manageTags', { ns: 'common' })}
+
setShowTagManagementModal(false)}>
diff --git a/web/app/components/base/tag-management/panel.tsx b/web/app/components/base/tag-management/panel.tsx index cceb09b4d7..db0aae1b05 100644 --- a/web/app/components/base/tag-management/panel.tsx +++ b/web/app/components/base/tag-management/panel.tsx @@ -114,7 +114,7 @@ const Panel = (props: PanelProps) => {
-
+
{`${t('tag.create', { ns: 'common' })} `} {`'${keywords}'`}
@@ -127,7 +127,7 @@ const Panel = (props: PanelProps) => { {filteredSelectedTagList.map(tag => (
selectTag(tag)} data-testid="tag-row"> -
+
{tag.name}
@@ -135,7 +135,7 @@ const Panel = (props: PanelProps) => { {filteredTagList.map(tag => (
selectTag(tag)} data-testid="tag-row"> -
+
{tag.name}
@@ -146,7 +146,7 @@ const Panel = (props: PanelProps) => {
-
{t('tag.noTag', { ns: 'common' })}
+
{t('tag.noTag', { ns: 'common' })}
)} @@ -154,7 +154,7 @@ const Panel = (props: PanelProps) => {
setShowTagManagementModal(true)}> -
+
{t('tag.manageTags', { ns: 'common' })}
diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/__tests__/index.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/__tests__/index.spec.tsx index 0ae553ec01..a4b8888b27 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/__tests__/index.spec.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/__tests__/index.spec.tsx @@ -98,10 +98,10 @@ describe('CloudPlanItem', () => { />, ) - expect(screen.getByText('billing.plans.sandbox.name')).toBeInTheDocument() - expect(screen.getByText('billing.plans.sandbox.description')).toBeInTheDocument() - expect(screen.getByText('billing.plansCommon.free')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' })).toBeInTheDocument() + expect(screen.getByText('billing.plans.sandbox.name'))!.toBeInTheDocument() + expect(screen.getByText('billing.plans.sandbox.description'))!.toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.free'))!.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))!.toBeInTheDocument() }) it('should display yearly pricing with discount when planRange is yearly', () => { @@ -115,9 +115,9 @@ describe('CloudPlanItem', () => { ) const professionalPlan = ALL_PLANS[Plan.professional] - expect(screen.getByText(`$${professionalPlan.price * 12}`)).toBeInTheDocument() - expect(screen.getByText(`$${professionalPlan.price * 10}`)).toBeInTheDocument() - expect(screen.getByText(/billing\.plansCommon\.priceTip.*billing\.plansCommon\.year/)).toBeInTheDocument() + expect(screen.getByText(`$${professionalPlan.price * 12}`))!.toBeInTheDocument() + expect(screen.getByText(`$${professionalPlan.price * 10}`))!.toBeInTheDocument() + expect(screen.getByText(/billing\.plansCommon\.priceTip.*billing\.plansCommon\.year/))!.toBeInTheDocument() }) it('should show "most popular" badge for professional plan', () => { @@ -130,7 +130,7 @@ describe('CloudPlanItem', () => { />, ) - expect(screen.getByText('billing.plansCommon.mostPopular')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.mostPopular'))!.toBeInTheDocument() }) it('should not show "most popular" badge for non-professional plans', () => { @@ -157,7 +157,7 @@ describe('CloudPlanItem', () => { ) const button = screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' }) - expect(button).toBeDisabled() + expect(button)!.toBeDisabled() }) }) @@ -176,7 +176,7 @@ describe('CloudPlanItem', () => { ) fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' })) - expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() + expect(screen.getByText('billing.buyPermissionDeniedTip'))!.toBeInTheDocument() expect(mockBillingInvoices).not.toHaveBeenCalled() }) @@ -320,7 +320,7 @@ describe('CloudPlanItem', () => { expect(openWindow).toHaveBeenCalledTimes(1) // The onError callback should have been passed to openAsyncWindow const callArgs = openWindow.mock.calls[0] - expect(callArgs[1]).toHaveProperty('onError') + expect(callArgs![1]).toHaveProperty('onError') }) }) @@ -336,8 +336,39 @@ describe('CloudPlanItem', () => { ) const teamPlan = ALL_PLANS[Plan.team] - expect(screen.getByText(`$${teamPlan.price}`)).toBeInTheDocument() - expect(screen.getByText(/billing\.plansCommon\.priceTip.*billing\.plansCommon\.month/)).toBeInTheDocument() + expect(screen.getByText(`$${teamPlan.price}`))!.toBeInTheDocument() + expect(screen.getByText(/billing\.plansCommon\.priceTip.*billing\.plansCommon\.month/))!.toBeInTheDocument() + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price + // Should NOT show crossed-out yearly price // Should NOT show crossed-out yearly price expect(screen.queryByText(`$${teamPlan.price * 12}`)).not.toBeInTheDocument() }) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx index b85f1d8631..99d956ba90 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx @@ -103,38 +103,38 @@ const CloudPlanItem: FC = ({ {ICON_MAP[plan]}
-
{t(`${i18nPrefix}.name`, { ns: 'billing' })}
+
{t(`${i18nPrefix}.name`, { ns: 'billing' })}
{ isMostPopularPlan && (
- + {t('plansCommon.mostPopular', { ns: 'billing' })}
) }
-
{t(`${i18nPrefix}.description`, { ns: 'billing' })}
+
{t(`${i18nPrefix}.description`, { ns: 'billing' })}
{/* Price */} -
+
{isFreePlan && ( - {t('plansCommon.free', { ns: 'billing' })} + {t('plansCommon.free', { ns: 'billing' })} )} {!isFreePlan && ( <> {isYear && ( - + $ {planInfo.price * 12} )} - + $ {isYear ? planInfo.price * 10 : planInfo.price} - + {t('plansCommon.priceTip', { ns: 'billing' })} {t(`plansCommon.${!isYear ? 'month' : 'year'}`, { ns: 'billing' })} diff --git a/web/app/components/datasets/common/image-previewer/index.tsx b/web/app/components/datasets/common/image-previewer/index.tsx index f899dcb33d..2cc51a3be5 100644 --- a/web/app/components/datasets/common/image-previewer/index.tsx +++ b/web/app/components/datasets/common/image-previewer/index.tsx @@ -137,7 +137,7 @@ const ImagePreviewer = ({ return { ...prev, [image.url]: { - ...prev[image.url], + ...prev[image.url]!, status: 'loading', }, } @@ -168,15 +168,15 @@ const ImagePreviewer = ({ Esc
- {cachedImages[currentImage.url].status === 'loading' && ( + {cachedImages[currentImage!.url]!.status === 'loading' && ( )} - {cachedImages[currentImage.url].status === 'error' && ( + {cachedImages[currentImage!.url]!.status === 'error' && (
- {`Failed to load image: ${currentImage.url}. Please try again.`} + {`Failed to load image: ${currentImage!.url}. Please try again.`}
)} - {cachedImages[currentImage.url].status === 'loaded' && ( + {cachedImages[currentImage!.url]!.status === 'loaded' && (
{currentImage.name}
- {currentImage.name} + {currentImage!.name} · - {`${cachedImages[currentImage.url].width} ×  ${cachedImages[currentImage.url].height}`} + {`${cachedImages[currentImage!.url]!.width} ×  ${cachedImages[currentImage!.url]!.height}`} · - {formatFileSize(currentImage.size)} + {formatFileSize(currentImage!.size)}
)} diff --git a/web/app/components/datasets/create/step-one/upgrade-card.tsx b/web/app/components/datasets/create/step-one/upgrade-card.tsx index 356e15ed43..e7016206ea 100644 --- a/web/app/components/datasets/create/step-one/upgrade-card.tsx +++ b/web/app/components/datasets/create/step-one/upgrade-card.tsx @@ -15,9 +15,9 @@ const UpgradeCard: FC = () => { }, [setShowPricingModal]) return ( -
+
-
{t('upgrade.uploadMultipleFiles.title', { ns: 'billing' })}
+
{t('upgrade.uploadMultipleFiles.title', { ns: 'billing' })}
{t('upgrade.uploadMultipleFiles.description', { ns: 'billing' })}
{ render() // Should render Previous and Next buttons with correct text - expect(screen.getByText(/previousStep/i)).toBeInTheDocument() - expect(screen.getByText(/nextStep/i)).toBeInTheDocument() + // Should render Previous and Next buttons with correct text + expect(screen.getByText(/previousStep/i))!.toBeInTheDocument() + expect(screen.getByText(/nextStep/i))!.toBeInTheDocument() }) it('should render Previous and Next buttons when not in setting mode', () => { render() - expect(screen.getByText(/previousStep/i)).toBeInTheDocument() - expect(screen.getByText(/nextStep/i)).toBeInTheDocument() + expect(screen.getByText(/previousStep/i))!.toBeInTheDocument() + expect(screen.getByText(/nextStep/i))!.toBeInTheDocument() }) it('should render Save and Cancel buttons when in setting mode', () => { render() - expect(screen.getByText(/save/i)).toBeInTheDocument() - expect(screen.getByText(/cancel/i)).toBeInTheDocument() + expect(screen.getByText(/save/i))!.toBeInTheDocument() + expect(screen.getByText(/cancel/i))!.toBeInTheDocument() }) }) @@ -1772,14 +1773,14 @@ describe('StepTwoFooter', () => { render() const nextButton = screen.getByText(/nextStep/i).closest('button') - expect(nextButton).toBeDisabled() + expect(nextButton)!.toBeDisabled() }) it('should show loading state on Save button when creating in setting mode', () => { render() const saveButton = screen.getByText(/save/i).closest('button') - expect(saveButton).toBeDisabled() + expect(saveButton)!.toBeDisabled() }) }) }) @@ -1811,18 +1812,50 @@ describe('PreviewPanel', () => { render() // Check for the preview header title text - expect(screen.getByText('datasetCreation.stepTwo.preview')).toBeInTheDocument() + // Check for the preview header title text + expect(screen.getByText('datasetCreation.stepTwo.preview'))!.toBeInTheDocument() }) it('should render idle state when isIdle is true', () => { render() - expect(screen.getByText(/previewChunkTip/i)).toBeInTheDocument() + expect(screen.getByText(/previewChunkTip/i))!.toBeInTheDocument() }) it('should render loading skeleton when isPending is true', () => { render() + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers + // Should show skeleton containers // Should show skeleton containers expect(screen.queryByText(/previewChunkTip/i)).not.toBeInTheDocument() }) @@ -1841,7 +1874,7 @@ describe('PreviewPanel', () => { />, ) - expect(screen.getByText('Chunk 1 content')).toBeInTheDocument() + expect(screen.getByText('Chunk 1 content'))!.toBeInTheDocument() }) it('should render QA preview when docForm is qa', () => { @@ -1855,8 +1888,8 @@ describe('PreviewPanel', () => { />, ) - expect(screen.getByText('Q1')).toBeInTheDocument() - expect(screen.getByText('A1')).toBeInTheDocument() + expect(screen.getByText('Q1'))!.toBeInTheDocument() + expect(screen.getByText('A1'))!.toBeInTheDocument() }) it('should show chunk count badge for non-QA doc form', () => { @@ -1870,7 +1903,7 @@ describe('PreviewPanel', () => { />, ) - expect(screen.getByText(/25/)).toBeInTheDocument() + expect(screen.getByText(/25/))!.toBeInTheDocument() }) it('should render parent-child preview when docForm is parentChild', () => { @@ -1893,11 +1926,13 @@ describe('PreviewPanel', () => { ) // Should render parent chunk label - expect(screen.getByText('Chunk-1')).toBeInTheDocument() + // Should render parent chunk label + expect(screen.getByText('Chunk-1'))!.toBeInTheDocument() // Should render child chunks - expect(screen.getByText('Child 1')).toBeInTheDocument() - expect(screen.getByText('Child 2')).toBeInTheDocument() - expect(screen.getByText('Child 3')).toBeInTheDocument() + // Should render child chunks + expect(screen.getByText('Child 1'))!.toBeInTheDocument() + expect(screen.getByText('Child 2'))!.toBeInTheDocument() + expect(screen.getByText('Child 3'))!.toBeInTheDocument() }) it('should limit child chunks when chunkForContext is full-doc', () => { @@ -1920,10 +1955,43 @@ describe('PreviewPanel', () => { ) // Should render parent chunk - expect(screen.getByText('Chunk-1')).toBeInTheDocument() + // Should render parent chunk + expect(screen.getByText('Chunk-1'))!.toBeInTheDocument() // full-doc mode limits to FULL_DOC_PREVIEW_LENGTH (50) - expect(screen.getByText('ChildChunk1')).toBeInTheDocument() - expect(screen.getByText('ChildChunk50')).toBeInTheDocument() + // full-doc mode limits to FULL_DOC_PREVIEW_LENGTH (50) + expect(screen.getByText('ChildChunk1'))!.toBeInTheDocument() + expect(screen.getByText('ChildChunk50'))!.toBeInTheDocument() + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit + // Should not render beyond the limit // Should not render beyond the limit expect(screen.queryByText('ChildChunk51')).not.toBeInTheDocument() }) @@ -1944,10 +2012,10 @@ describe('PreviewPanel', () => { />, ) - expect(screen.getByText('Chunk-1')).toBeInTheDocument() - expect(screen.getByText('Chunk-2')).toBeInTheDocument() - expect(screen.getByText('P1-C1')).toBeInTheDocument() - expect(screen.getByText('P2-C1')).toBeInTheDocument() + expect(screen.getByText('Chunk-1'))!.toBeInTheDocument() + expect(screen.getByText('Chunk-2'))!.toBeInTheDocument() + expect(screen.getByText('P1-C1'))!.toBeInTheDocument() + expect(screen.getByText('P2-C1'))!.toBeInTheDocument() }) }) @@ -2222,19 +2290,20 @@ describe('StepTwo Component', () => { describe('Rendering', () => { it('should render without crashing', () => { render() - expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() }) it('should show general chunking options when not in upload', () => { render() // Should render the segmentation section - expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() + // Should render the segmentation section + expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() }) it('should show footer with Previous and Next buttons', () => { render() - expect(screen.getByText(/stepTwo\.previousStep/i)).toBeInTheDocument() - expect(screen.getByText(/stepTwo\.nextStep/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.previousStep/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.nextStep/i))!.toBeInTheDocument() }) }) @@ -2282,7 +2351,7 @@ describe('StepTwo Component', () => { render() // GeneralChunkingOptions renders a "Preview Chunk" button const previewButtons = screen.getAllByText(/stepTwo\.previewChunk/i) - fireEvent.click(previewButtons[0]) + fireEvent.click(previewButtons[0]!) // updatePreview calls estimateHook.fetchEstimate() // No error means the handler executed successfully }) @@ -2292,7 +2361,7 @@ describe('StepTwo Component', () => { // ParentChildOptions renders an OptionCard; find the title element and click its parent card const parentChildTitles = screen.getAllByText(/stepTwo\.parentChild/i) // The first match is the title; click it to trigger onDocFormChange - fireEvent.click(parentChildTitles[0]) + fireEvent.click(parentChildTitles[0]!) // handleDocFormChange sets docForm, segmentationType, and resets estimate }) }) @@ -2307,7 +2376,8 @@ describe('StepTwo Component', () => { />, ) // When currentDataset has parentChild doc_form, should show parent-child option - expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() + // When currentDataset has parentChild doc_form, should show parent-child option + expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() }) it('should render setting mode with Save/Cancel buttons', () => { @@ -2320,8 +2390,8 @@ describe('StepTwo Component', () => { datasetId="test-id" />, ) - expect(screen.getByText(/stepTwo\.save/i)).toBeInTheDocument() - expect(screen.getByText(/stepTwo\.cancel/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.save/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.cancel/i))!.toBeInTheDocument() }) it('should call onCancel when Cancel button is clicked in setting mode', () => { @@ -2362,8 +2432,9 @@ describe('StepTwo Component', () => { it('should show both general and parent-child options in create page', () => { render() // When isInInit (no datasetId, no isSetting), both options should show - expect(screen.getByText('datasetCreation.stepTwo.general')).toBeInTheDocument() - expect(screen.getByText('datasetCreation.stepTwo.parentChild')).toBeInTheDocument() + // When isInInit (no datasetId, no isSetting), both options should show + expect(screen.getByText('datasetCreation.stepTwo.general'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.parentChild'))!.toBeInTheDocument() }) it('should only show parent-child option when dataset has parentChild doc_form', () => { @@ -2376,7 +2447,9 @@ describe('StepTwo Component', () => { ) // showGeneralOption should be false (parentChild not in [text, qa]) // showParentChildOption should be true - expect(screen.getByText('datasetCreation.stepTwo.parentChild')).toBeInTheDocument() + // showGeneralOption should be false (parentChild not in [text, qa]) + // showParentChildOption should be true + expect(screen.getByText('datasetCreation.stepTwo.parentChild'))!.toBeInTheDocument() }) it('should show general option only when dataset has text doc_form', () => { @@ -2388,7 +2461,8 @@ describe('StepTwo Component', () => { />, ) // showGeneralOption should be true (text is in [text, qa]) - expect(screen.getByText('datasetCreation.stepTwo.general')).toBeInTheDocument() + // showGeneralOption should be true (text is in [text, qa]) + expect(screen.getByText('datasetCreation.stepTwo.general'))!.toBeInTheDocument() }) }) @@ -2401,7 +2475,7 @@ describe('StepTwo Component', () => { datasetId="test-id" />, ) - expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() }) it('should show general option for empty dataset (no doc_form)', () => { @@ -2413,7 +2487,7 @@ describe('StepTwo Component', () => { datasetId="test-id" />, ) - expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() }) it('should show both options in empty dataset upload', () => { @@ -2426,8 +2500,9 @@ describe('StepTwo Component', () => { />, ) // isUploadInEmptyDataset=true shows both options - expect(screen.getByText('datasetCreation.stepTwo.general')).toBeInTheDocument() - expect(screen.getByText('datasetCreation.stepTwo.parentChild')).toBeInTheDocument() + // isUploadInEmptyDataset=true shows both options + expect(screen.getByText('datasetCreation.stepTwo.general'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.parentChild'))!.toBeInTheDocument() }) }) @@ -2435,19 +2510,22 @@ describe('StepTwo Component', () => { it('should render indexing mode section', () => { render() // IndexingModeSection renders the index mode title - expect(screen.getByText(/stepTwo\.indexMode/i)).toBeInTheDocument() + // IndexingModeSection renders the index mode title + expect(screen.getByText(/stepTwo\.indexMode/i))!.toBeInTheDocument() }) it('should render embedding model selector when QUALIFIED', () => { render() // ModelSelector is mocked and rendered with data-testid - expect(screen.getByTestId('model-selector')).toBeInTheDocument() + // ModelSelector is mocked and rendered with data-testid + expect(screen.getByTestId('model-selector'))!.toBeInTheDocument() }) it('should render retrieval method config', () => { render() // RetrievalMethodConfig is mocked with data-testid - expect(screen.getByTestId('retrieval-method-config')).toBeInTheDocument() + // RetrievalMethodConfig is mocked with data-testid + expect(screen.getByTestId('retrieval-method-config'))!.toBeInTheDocument() }) it('should disable model and retrieval config when datasetId has existing data source', () => { @@ -2460,14 +2538,14 @@ describe('StepTwo Component', () => { ) // isModelAndRetrievalConfigDisabled should be true const modelSelector = screen.getByTestId('model-selector') - expect(modelSelector).toHaveAttribute('data-readonly', 'true') + expect(modelSelector)!.toHaveAttribute('data-readonly', 'true') }) }) describe('Preview Panel', () => { it('should render preview panel', () => { render() - expect(screen.getByText('datasetCreation.stepTwo.preview')).toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.preview'))!.toBeInTheDocument() }) it('should hide document picker in setting mode', () => { @@ -2481,7 +2559,8 @@ describe('StepTwo Component', () => { />, ) // Preview panel should still render - expect(screen.getByText('datasetCreation.stepTwo.preview')).toBeInTheDocument() + // Preview panel should still render + expect(screen.getByText('datasetCreation.stepTwo.preview'))!.toBeInTheDocument() }) }) @@ -2498,35 +2577,35 @@ describe('StepTwo Component', () => { it('should switch to QUALIFIED when selecting parentChild in ECONOMICAL mode', async () => { render() await vi.waitFor(() => { - expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() }) const parentChildTitles = screen.getAllByText(/stepTwo\.parentChild/i) - fireEvent.click(parentChildTitles[0]) + fireEvent.click(parentChildTitles[0]!) }) it('should open QA confirm dialog and confirm switch when QA selected in ECONOMICAL mode', async () => { render() await vi.waitFor(() => { - expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() }) const qaCheckbox = screen.getByText(/stepTwo\.useQALanguage/i) fireEvent.click(qaCheckbox) // Dialog should open → click Switch to confirm (triggers handleQAConfirm) const switchButton = await screen.findByText(/stepTwo\.switch/i) - expect(switchButton).toBeInTheDocument() + expect(switchButton)!.toBeInTheDocument() fireEvent.click(switchButton) }) it('should close QA confirm dialog when cancel is clicked', async () => { render() await vi.waitFor(() => { - expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() }) // Open QA confirm dialog const qaCheckbox = screen.getByText(/stepTwo\.useQALanguage/i) fireEvent.click(qaCheckbox) const dialogCancelButtons = await screen.findAllByText(/stepTwo\.cancel/i) - fireEvent.click(dialogCancelButtons[0]) + fireEvent.click(dialogCancelButtons[0]!) }) it('should handle picker change when selecting a different file', () => { @@ -2545,7 +2624,7 @@ describe('StepTwo Component', () => { render() // The default maxChunkLength (1024) now exceeds the limit (100) const previewButtons = screen.getAllByText(/stepTwo\.previewChunk/i) - fireEvent.click(previewButtons[0]) + fireEvent.click(previewButtons[0]!) // Restore document.body.removeAttribute('data-public-indexing-max-segmentation-tokens-length') }) diff --git a/web/app/components/datasets/create/step-two/components/__tests__/inputs.spec.tsx b/web/app/components/datasets/create/step-two/components/__tests__/inputs.spec.tsx index 2c0480e508..f1ab5392ce 100644 --- a/web/app/components/datasets/create/step-two/components/__tests__/inputs.spec.tsx +++ b/web/app/components/datasets/create/step-two/components/__tests__/inputs.spec.tsx @@ -12,26 +12,27 @@ describe('DelimiterInput', () => { it('should render separator label', () => { render() - expect(screen.getByText(`${ns}.stepTwo.separator`)).toBeInTheDocument() + expect(screen.getByText(`${ns}.stepTwo.separator`))!.toBeInTheDocument() }) it('should render text input with placeholder', () => { render() const input = screen.getByPlaceholderText(`${ns}.stepTwo.separatorPlaceholder`) - expect(input).toBeInTheDocument() - expect(input).toHaveAttribute('type', 'text') + expect(input)!.toBeInTheDocument() + expect(input)!.toHaveAttribute('type', 'text') }) it('should pass through value and onChange props', () => { const onChange = vi.fn() render() - expect(screen.getByDisplayValue('test-val')).toBeInTheDocument() + expect(screen.getByDisplayValue('test-val'))!.toBeInTheDocument() }) it('should render tooltip content', () => { render() // Tooltip triggers render; component mounts without error - expect(screen.getByText(`${ns}.stepTwo.separator`)).toBeInTheDocument() + // Tooltip triggers render; component mounts without error + expect(screen.getByText(`${ns}.stepTwo.separator`))!.toBeInTheDocument() }) it('should suppress onChange during IME composition', () => { @@ -47,7 +48,7 @@ describe('DelimiterInput', () => { fireEvent.compositionEnd(input) expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange.mock.calls[0][0].target.value).toBe(finalValue) + expect(onChange.mock.calls[0]![0].target.value).toBe(finalValue) }) }) @@ -58,24 +59,24 @@ describe('MaxLengthInput', () => { it('should render max length label', () => { render() - expect(screen.getByText(`${ns}.stepTwo.maxLength`)).toBeInTheDocument() + expect(screen.getByText(`${ns}.stepTwo.maxLength`))!.toBeInTheDocument() }) it('should render number input', () => { render() const input = screen.getByRole('textbox') - expect(input).toBeInTheDocument() + expect(input)!.toBeInTheDocument() }) it('should accept value prop', () => { render() - expect(screen.getByRole('textbox')).toHaveValue('500') + expect(screen.getByRole('textbox'))!.toHaveValue('500') }) it('should have min of 1', () => { render() const input = screen.getByRole('textbox') - expect(input).toBeInTheDocument() + expect(input)!.toBeInTheDocument() }) it('should reset to the minimum when users clear the value', () => { @@ -107,18 +108,18 @@ describe('OverlapInput', () => { it('should render number input', () => { render() const input = screen.getByRole('textbox') - expect(input).toBeInTheDocument() + expect(input)!.toBeInTheDocument() }) it('should accept value prop', () => { render() - expect(screen.getByRole('textbox')).toHaveValue('50') + expect(screen.getByRole('textbox'))!.toHaveValue('50') }) it('should have min of 1', () => { render() const input = screen.getByRole('textbox') - expect(input).toBeInTheDocument() + expect(input)!.toBeInTheDocument() }) it('should reset to the minimum when users clear the value', () => { diff --git a/web/app/components/datasets/create/website/firecrawl/__tests__/options.spec.tsx b/web/app/components/datasets/create/website/firecrawl/__tests__/options.spec.tsx index 313ad9c051..946c04aa93 100644 --- a/web/app/components/datasets/create/website/firecrawl/__tests__/options.spec.tsx +++ b/web/app/components/datasets/create/website/firecrawl/__tests__/options.spec.tsx @@ -35,9 +35,10 @@ describe('Options', () => { render() // Check that key elements are rendered - expect(screen.getByText(/crawlSubPage/i)).toBeInTheDocument() - expect(screen.getByText(/limit/i)).toBeInTheDocument() - expect(screen.getByText(/maxDepth/i)).toBeInTheDocument() + // Check that key elements are rendered + expect(screen.getByText(/crawlSubPage/i))!.toBeInTheDocument() + expect(screen.getByText(/limit/i))!.toBeInTheDocument() + expect(screen.getByText(/maxDepth/i))!.toBeInTheDocument() }) it('should render all form fields', () => { @@ -45,14 +46,16 @@ describe('Options', () => { render() // Checkboxes - expect(screen.getByText(/crawlSubPage/i)).toBeInTheDocument() - expect(screen.getByText(/extractOnlyMainContent/i)).toBeInTheDocument() + // Checkboxes + expect(screen.getByText(/crawlSubPage/i))!.toBeInTheDocument() + expect(screen.getByText(/extractOnlyMainContent/i))!.toBeInTheDocument() // Text/Number fields - expect(screen.getByText(/limit/i)).toBeInTheDocument() - expect(screen.getByText(/maxDepth/i)).toBeInTheDocument() - expect(screen.getByText(/excludePaths/i)).toBeInTheDocument() - expect(screen.getByText(/includeOnlyPaths/i)).toBeInTheDocument() + // Text/Number fields + expect(screen.getByText(/limit/i))!.toBeInTheDocument() + expect(screen.getByText(/maxDepth/i))!.toBeInTheDocument() + expect(screen.getByText(/excludePaths/i))!.toBeInTheDocument() + expect(screen.getByText(/includeOnlyPaths/i))!.toBeInTheDocument() }) it('should render with custom className', () => { @@ -62,7 +65,7 @@ describe('Options', () => { ) const rootElement = container.firstChild as HTMLElement - expect(rootElement).toHaveClass('custom-class') + expect(rootElement)!.toHaveClass('custom-class') }) it('should render limit field with required indicator', () => { @@ -71,7 +74,7 @@ describe('Options', () => { // Limit field should have required indicator (*) const requiredIndicator = screen.getByText('*') - expect(requiredIndicator).toBeInTheDocument() + expect(requiredIndicator)!.toBeInTheDocument() }) it('should render placeholder for excludes field', () => { @@ -79,7 +82,7 @@ describe('Options', () => { render() const excludesInput = screen.getByPlaceholderText('blog/*, /about/*') - expect(excludesInput).toBeInTheDocument() + expect(excludesInput)!.toBeInTheDocument() }) it('should render placeholder for includes field', () => { @@ -87,7 +90,7 @@ describe('Options', () => { render() const includesInput = screen.getByPlaceholderText('articles/*') - expect(includesInput).toBeInTheDocument() + expect(includesInput)!.toBeInTheDocument() }) it('should render two checkboxes', () => { @@ -106,7 +109,8 @@ describe('Options', () => { render() // First checkbox should have check icon when checked - expect(screen.queryByTestId('check-icon-crawl-sub-page')).toBeInTheDocument() + // First checkbox should have check icon when checked + expect(screen.queryByTestId('check-icon-crawl-sub-page'))!.toBeInTheDocument() }) it('should display crawl_sub_pages checkbox without check icon when false', () => { @@ -118,7 +122,7 @@ describe('Options', () => { it('should display only_main_content checkbox with check icon when true', () => { const payload = createMockCrawlOptions({ only_main_content: true }) render() - expect(screen.getByTestId('check-icon-only-main-content')).toBeInTheDocument() + expect(screen.getByTestId('check-icon-only-main-content'))!.toBeInTheDocument() }) it('should display only_main_content checkbox without check icon when false', () => { @@ -132,7 +136,7 @@ describe('Options', () => { render() const limitInput = screen.getByDisplayValue('25') - expect(limitInput).toBeInTheDocument() + expect(limitInput)!.toBeInTheDocument() }) it('should display max_depth value in input', () => { @@ -140,7 +144,7 @@ describe('Options', () => { render() const maxDepthInput = screen.getByDisplayValue('5') - expect(maxDepthInput).toBeInTheDocument() + expect(maxDepthInput)!.toBeInTheDocument() }) it('should display excludes value in input', () => { @@ -148,7 +152,7 @@ describe('Options', () => { render() const excludesInput = screen.getByDisplayValue('test/*') - expect(excludesInput).toBeInTheDocument() + expect(excludesInput)!.toBeInTheDocument() }) it('should display includes value in input', () => { @@ -156,7 +160,7 @@ describe('Options', () => { render() const includesInput = screen.getByDisplayValue('docs/*') - expect(includesInput).toBeInTheDocument() + expect(includesInput)!.toBeInTheDocument() }) }) @@ -166,7 +170,7 @@ describe('Options', () => { const { container } = render() const checkboxes = getCheckboxes(container) - fireEvent.click(checkboxes[0]) + fireEvent.click(checkboxes[0]!) expect(mockOnChange).toHaveBeenCalledWith({ ...payload, @@ -179,7 +183,7 @@ describe('Options', () => { const { container } = render() const checkboxes = getCheckboxes(container) - fireEvent.click(checkboxes[1]) + fireEvent.click(checkboxes[1]!) expect(mockOnChange).toHaveBeenCalledWith({ ...payload, @@ -251,7 +255,8 @@ describe('Options', () => { render() // Component should render without crashing - expect(screen.getByText(/limit/i)).toBeInTheDocument() + // Component should render without crashing + expect(screen.getByText(/limit/i))!.toBeInTheDocument() }) it('should handle zero values', () => { @@ -273,8 +278,8 @@ describe('Options', () => { }) render() - expect(screen.getByDisplayValue('9999')).toBeInTheDocument() - expect(screen.getByDisplayValue('100')).toBeInTheDocument() + expect(screen.getByDisplayValue('9999'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('100'))!.toBeInTheDocument() }) it('should handle special characters in text fields', () => { @@ -284,8 +289,8 @@ describe('Options', () => { }) render() - expect(screen.getByDisplayValue('path/*/file?query=1¶m=2')).toBeInTheDocument() - expect(screen.getByDisplayValue('docs/**/*.md')).toBeInTheDocument() + expect(screen.getByDisplayValue('path/*/file?query=1¶m=2'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('docs/**/*.md'))!.toBeInTheDocument() }) it('should preserve other payload fields when updating one field', () => { @@ -357,7 +362,7 @@ describe('Options', () => { rerender() - expect(screen.getByText(/limit/i)).toBeInTheDocument() + expect(screen.getByText(/limit/i))!.toBeInTheDocument() }) it('should re-render when payload changes', () => { @@ -365,10 +370,10 @@ describe('Options', () => { const payload2 = createMockCrawlOptions({ limit: 20 }) const { rerender } = render() - expect(screen.getByDisplayValue('10')).toBeInTheDocument() + expect(screen.getByDisplayValue('10'))!.toBeInTheDocument() rerender() - expect(screen.getByDisplayValue('20')).toBeInTheDocument() + expect(screen.getByDisplayValue('20'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/create/website/watercrawl/__tests__/options.spec.tsx b/web/app/components/datasets/create/website/watercrawl/__tests__/options.spec.tsx index bda01dc152..b65df109ad 100644 --- a/web/app/components/datasets/create/website/watercrawl/__tests__/options.spec.tsx +++ b/web/app/components/datasets/create/website/watercrawl/__tests__/options.spec.tsx @@ -34,12 +34,12 @@ describe('Options (watercrawl)', () => { const payload = createMockCrawlOptions() render() - expect(screen.getByText(/crawlSubPage/i)).toBeInTheDocument() - expect(screen.getByText(/extractOnlyMainContent/i)).toBeInTheDocument() - expect(screen.getByText(/limit/i)).toBeInTheDocument() - expect(screen.getByText(/maxDepth/i)).toBeInTheDocument() - expect(screen.getByText(/excludePaths/i)).toBeInTheDocument() - expect(screen.getByText(/includeOnlyPaths/i)).toBeInTheDocument() + expect(screen.getByText(/crawlSubPage/i))!.toBeInTheDocument() + expect(screen.getByText(/extractOnlyMainContent/i))!.toBeInTheDocument() + expect(screen.getByText(/limit/i))!.toBeInTheDocument() + expect(screen.getByText(/maxDepth/i))!.toBeInTheDocument() + expect(screen.getByText(/excludePaths/i))!.toBeInTheDocument() + expect(screen.getByText(/includeOnlyPaths/i))!.toBeInTheDocument() }) it('should render two checkboxes', () => { @@ -55,21 +55,21 @@ describe('Options (watercrawl)', () => { render() const requiredIndicator = screen.getByText('*') - expect(requiredIndicator).toBeInTheDocument() + expect(requiredIndicator)!.toBeInTheDocument() }) it('should render placeholder for excludes field', () => { const payload = createMockCrawlOptions() render() - expect(screen.getByPlaceholderText('blog/*, /about/*')).toBeInTheDocument() + expect(screen.getByPlaceholderText('blog/*, /about/*'))!.toBeInTheDocument() }) it('should render placeholder for includes field', () => { const payload = createMockCrawlOptions() render() - expect(screen.getByPlaceholderText('articles/*')).toBeInTheDocument() + expect(screen.getByPlaceholderText('articles/*'))!.toBeInTheDocument() }) it('should render with custom className', () => { @@ -79,7 +79,7 @@ describe('Options (watercrawl)', () => { ) const rootElement = container.firstChild as HTMLElement - expect(rootElement).toHaveClass('custom-class') + expect(rootElement)!.toHaveClass('custom-class') }) }) @@ -89,7 +89,7 @@ describe('Options (watercrawl)', () => { const payload = createMockCrawlOptions({ crawl_sub_pages: true }) render() - expect(screen.getByTestId('check-icon-crawl-sub-pages')).toBeInTheDocument() + expect(screen.getByTestId('check-icon-crawl-sub-pages'))!.toBeInTheDocument() }) it('should display crawl_sub_pages checkbox without check icon when false', () => { @@ -97,13 +97,13 @@ describe('Options (watercrawl)', () => { const { container } = render() const checkboxes = getCheckboxes(container) - expect(checkboxes[0].querySelector('svg')).not.toBeInTheDocument() + expect(checkboxes[0]!.querySelector('svg')).not.toBeInTheDocument() }) it('should display only_main_content checkbox with check icon when true', () => { const payload = createMockCrawlOptions({ only_main_content: true }) render() - expect(screen.getByTestId('check-icon-only-main-content')).toBeInTheDocument() + expect(screen.getByTestId('check-icon-only-main-content'))!.toBeInTheDocument() }) it('should display only_main_content checkbox without check icon when false', () => { @@ -111,35 +111,35 @@ describe('Options (watercrawl)', () => { const { container } = render() const checkboxes = getCheckboxes(container) - expect(checkboxes[1].querySelector('svg')).not.toBeInTheDocument() + expect(checkboxes[1]!.querySelector('svg')).not.toBeInTheDocument() }) it('should display limit value in input', () => { const payload = createMockCrawlOptions({ limit: 25 }) render() - expect(screen.getByDisplayValue('25')).toBeInTheDocument() + expect(screen.getByDisplayValue('25'))!.toBeInTheDocument() }) it('should display max_depth value in input', () => { const payload = createMockCrawlOptions({ max_depth: 5 }) render() - expect(screen.getByDisplayValue('5')).toBeInTheDocument() + expect(screen.getByDisplayValue('5'))!.toBeInTheDocument() }) it('should display excludes value in input', () => { const payload = createMockCrawlOptions({ excludes: 'test/*' }) render() - expect(screen.getByDisplayValue('test/*')).toBeInTheDocument() + expect(screen.getByDisplayValue('test/*'))!.toBeInTheDocument() }) it('should display includes value in input', () => { const payload = createMockCrawlOptions({ includes: 'docs/*' }) render() - expect(screen.getByDisplayValue('docs/*')).toBeInTheDocument() + expect(screen.getByDisplayValue('docs/*'))!.toBeInTheDocument() }) }) @@ -149,7 +149,7 @@ describe('Options (watercrawl)', () => { const { container } = render() const checkboxes = getCheckboxes(container) - fireEvent.click(checkboxes[0]) + fireEvent.click(checkboxes[0]!) expect(mockOnChange).toHaveBeenCalledWith({ ...payload, @@ -162,7 +162,7 @@ describe('Options (watercrawl)', () => { const { container } = render() const checkboxes = getCheckboxes(container) - fireEvent.click(checkboxes[1]) + fireEvent.click(checkboxes[1]!) expect(mockOnChange).toHaveBeenCalledWith({ ...payload, @@ -264,10 +264,10 @@ describe('Options (watercrawl)', () => { const payload2 = createMockCrawlOptions({ limit: 20 }) const { rerender } = render() - expect(screen.getByDisplayValue('10')).toBeInTheDocument() + expect(screen.getByDisplayValue('10'))!.toBeInTheDocument() rerender() - expect(screen.getByDisplayValue('20')).toBeInTheDocument() + expect(screen.getByDisplayValue('20'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx index 0d60ef86db..c89059d185 100644 --- a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx +++ b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx @@ -105,7 +105,7 @@ describe('Operations', () => { describe('rendering', () => { it('should render without crashing', () => { render() - expect(document.querySelector('.flex.items-center')).toBeInTheDocument() + expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() }) it('should render buttons when embeddingAvailable', () => { @@ -122,7 +122,7 @@ describe('Operations', () => { it('should render disabled switch when embeddingAvailable is false in list scene', () => { render() const disabledSwitch = screen.getByRole('switch') - expect(disabledSwitch).toHaveAttribute('aria-disabled', 'true') + expect(disabledSwitch)!.toHaveAttribute('aria-disabled', 'true') }) }) @@ -209,7 +209,7 @@ describe('Operations', () => { const buttons = screen.getAllByRole('button') const settingsButton = buttons[0] await act(async () => { - fireEvent.click(settingsButton) + fireEvent.click(settingsButton!) }) expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1/settings') }) @@ -219,7 +219,7 @@ describe('Operations', () => { it('should render differently in detail scene', () => { render() const container = document.querySelector('.flex.items-center') - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) it('should not render switch in detail scene', () => { @@ -239,7 +239,7 @@ describe('Operations', () => { onSelectedIdChange={mockOnSelectedIdChange} />, ) - expect(document.querySelector('.flex.items-center')).toBeInTheDocument() + expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() }) }) @@ -257,7 +257,8 @@ describe('Operations', () => { render() await openPopover() // Check if popover content is visible - expect(screen.getByText('datasetDocuments.list.table.rename')).toBeInTheDocument() + // Check if popover content is visible + expect(screen.getByText('datasetDocuments.list.table.rename'))!.toBeInTheDocument() }) it('should call archive when archive action is clicked', async () => { @@ -297,7 +298,8 @@ describe('Operations', () => { fireEvent.click(deleteButton) }) // Check if confirmation modal is shown - expect(screen.getByText('datasetDocuments.list.delete.title')).toBeInTheDocument() + // Check if confirmation modal is shown + expect(screen.getByText('datasetDocuments.list.delete.title'))!.toBeInTheDocument() }) it('should call delete when confirm is clicked in delete modal', async () => { @@ -324,7 +326,8 @@ describe('Operations', () => { fireEvent.click(deleteButton) }) // Verify modal is shown - expect(screen.getByText('datasetDocuments.list.delete.title')).toBeInTheDocument() + // Verify modal is shown + expect(screen.getByText('datasetDocuments.list.delete.title'))!.toBeInTheDocument() // Find and click the cancel button const cancelButton = screen.getByText('common.operation.cancel') await act(async () => { @@ -366,7 +369,7 @@ describe('Operations', () => { await user.click(renameAction) const renameInput = await screen.findByRole('textbox') - expect(renameInput).toHaveValue('Test Document') + expect(renameInput)!.toHaveValue('Test Document') }) it('should call sync for notion data source', async () => { @@ -458,7 +461,7 @@ describe('Operations', () => { />, ) await openPopover() - expect(screen.getByText('datasetDocuments.list.action.download')).toBeInTheDocument() + expect(screen.getByText('datasetDocuments.list.action.download'))!.toBeInTheDocument() }) it('should download archived file when download is clicked', async () => { @@ -543,7 +546,7 @@ describe('Operations', () => { detail={{ ...defaultDetail, display_status: 'indexing' }} />, ) - expect(document.querySelector('.flex.items-center')).toBeInTheDocument() + expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() }) it('should render resume action when status is paused', () => { @@ -553,7 +556,7 @@ describe('Operations', () => { detail={{ ...defaultDetail, display_status: 'paused' }} />, ) - expect(document.querySelector('.flex.items-center')).toBeInTheDocument() + expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() }) it('should not show pause/resume for available status', async () => { @@ -582,7 +585,7 @@ describe('Operations', () => { detail={{ ...defaultDetail, data_source_type: 'notion_import' }} />, ) - expect(document.querySelector('.flex.items-center')).toBeInTheDocument() + expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() }) it('should handle web data source type', () => { @@ -592,7 +595,7 @@ describe('Operations', () => { detail={{ ...defaultDetail, data_source_type: 'website_crawl' }} />, ) - expect(document.querySelector('.flex.items-center')).toBeInTheDocument() + expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() }) it('should not show download for non-file data source', async () => { @@ -622,7 +625,7 @@ describe('Operations', () => { it('should accept custom className prop', () => { // The className is passed to CustomPopover, verify component renders without errors render() - expect(document.querySelector('.flex.items-center')).toBeInTheDocument() + expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx index 97ae1c92a1..01d6299492 100644 --- a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx @@ -112,25 +112,26 @@ describe('DocumentList', () => { describe('Rendering', () => { it('should render without crashing', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render all documents', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('Document 1.txt')).toBeInTheDocument() - expect(screen.getByText('Document 2.txt')).toBeInTheDocument() - expect(screen.getByText('Document 3.txt')).toBeInTheDocument() + expect(screen.getByText('Document 1.txt'))!.toBeInTheDocument() + expect(screen.getByText('Document 2.txt'))!.toBeInTheDocument() + expect(screen.getByText('Document 3.txt'))!.toBeInTheDocument() }) it('should render table headers', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('#')).toBeInTheDocument() + expect(screen.getByText('#'))!.toBeInTheDocument() }) it('should render pagination when total is provided', () => { render(, { wrapper: createWrapper() }) // Pagination component should be present - expect(screen.getByRole('table')).toBeInTheDocument() + // Pagination component should be present + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should not render pagination when total is 0', () => { @@ -139,13 +140,13 @@ describe('DocumentList', () => { pagination: { ...defaultPagination, total: 0 }, } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render empty table when no documents', () => { const props = { ...defaultProps, documents: [] } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) }) @@ -165,7 +166,8 @@ describe('DocumentList', () => { const props = { ...defaultProps, embeddingAvailable: false } render(, { wrapper: createWrapper() }) // Row checkboxes should still be there, but header checkbox should be hidden - expect(screen.getByRole('table')).toBeInTheDocument() + // Row checkboxes should still be there, but header checkbox should be hidden + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should call onSelectedIdChange when select all is clicked', () => { @@ -175,7 +177,7 @@ describe('DocumentList', () => { const checkboxes = findCheckboxes(container) if (checkboxes.length > 0) { - fireEvent.click(checkboxes[0]) + fireEvent.click(checkboxes[0]!) expect(onSelectedIdChange).toHaveBeenCalled() } }) @@ -190,7 +192,7 @@ describe('DocumentList', () => { // When checked, checkbox should have a check icon (svg) inside props.selectedIds.forEach((id) => { const checkIcon = screen.getByTestId(`check-icon-doc-row-${id}`) - expect(checkIcon).toBeInTheDocument() + expect(checkIcon)!.toBeInTheDocument() }) }) @@ -206,7 +208,9 @@ describe('DocumentList', () => { expect(checkboxes.length).toBeGreaterThan(0) // Header checkbox should show indeterminate icon, not check icon // Just verify it's rendered - expect(checkboxes[0]).toBeInTheDocument() + // Header checkbox should show indeterminate icon, not check icon + // Just verify it's rendered + expect(checkboxes[0])!.toBeInTheDocument() }) it('should call onSelectedIdChange with single document when row checkbox is clicked', () => { @@ -216,7 +220,7 @@ describe('DocumentList', () => { const checkboxes = findCheckboxes(container) if (checkboxes.length > 1) { - fireEvent.click(checkboxes[1]) + fireEvent.click(checkboxes[1]!) expect(onSelectedIdChange).toHaveBeenCalled() } }) @@ -236,7 +240,7 @@ describe('DocumentList', () => { const sortableHeaders = container.querySelectorAll('thead button') if (sortableHeaders.length > 0) - fireEvent.click(sortableHeaders[0]) + fireEvent.click(sortableHeaders[0]!) expect(onSortChange).toHaveBeenCalled() }) @@ -251,14 +255,16 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // BatchAction component should be visible - expect(screen.getByRole('table')).toBeInTheDocument() + // BatchAction component should be visible + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should not show batch action bar when no documents selected', () => { render(, { wrapper: createWrapper() }) // BatchAction should not be present - expect(screen.getByRole('table')).toBeInTheDocument() + // BatchAction should not be present + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render batch action bar with archive option', () => { @@ -269,7 +275,8 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // BatchAction component should be visible when documents are selected - expect(screen.getByRole('table')).toBeInTheDocument() + // BatchAction component should be visible when documents are selected + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render batch action bar with enable option', () => { @@ -279,7 +286,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render batch action bar with disable option', () => { @@ -289,7 +296,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render batch action bar with delete option', () => { @@ -299,7 +306,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should clear selection when cancel is clicked', () => { @@ -329,7 +336,8 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // BatchAction should be visible - expect(screen.getByRole('table')).toBeInTheDocument() + // BatchAction should be visible + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should show re-index option for error documents', () => { @@ -343,7 +351,8 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // BatchAction with re-index should be present for error documents - expect(screen.getByRole('table')).toBeInTheDocument() + // BatchAction with re-index should be present for error documents + expect(screen.getByRole('table'))!.toBeInTheDocument() }) }) @@ -354,7 +363,7 @@ describe('DocumentList', () => { const rows = screen.getAllByRole('row') // First row is header, second row is first document if (rows.length > 1) { - fireEvent.click(rows[1]) + fireEvent.click(rows[1]!) expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1') } }) @@ -376,11 +385,11 @@ describe('DocumentList', () => { const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') if (renameButtons.length > 0) { await act(async () => { - fireEvent.click(renameButtons[0]) + fireEvent.click(renameButtons[0]!) }) } - expect(screen.getByRole('dialog', { name: 'datasetDocuments.list.table.rename' })).toBeInTheDocument() + expect(screen.getByRole('dialog', { name: 'datasetDocuments.list.table.rename' }))!.toBeInTheDocument() }) it('should call onUpdate when document is renamed', () => { @@ -389,7 +398,8 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // The handleRenamed callback wraps onUpdate - expect(screen.getByRole('table')).toBeInTheDocument() + // The handleRenamed callback wraps onUpdate + expect(screen.getByRole('table'))!.toBeInTheDocument() }) }) @@ -408,7 +418,7 @@ describe('DocumentList', () => { }) } - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should call onManageMetadata when manage metadata is triggered', () => { @@ -421,26 +431,27 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // The onShowManage callback in EditMetadataBatchModal should call hideEditModal then onManageMetadata - expect(screen.getByRole('table')).toBeInTheDocument() + // The onShowManage callback in EditMetadataBatchModal should call hideEditModal then onManageMetadata + expect(screen.getByRole('table'))!.toBeInTheDocument() }) }) describe('Chunking Mode', () => { it('should render with general mode', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render with QA mode', () => { // This test uses the default mock which returns ChunkingMode.text // The component will compute isQAMode based on doc_form render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should render with parent-child mode', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) }) @@ -449,7 +460,7 @@ describe('DocumentList', () => { const props = { ...defaultProps, documents: [] } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should handle documents with missing optional fields', () => { @@ -463,7 +474,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should handle remote sort value', () => { @@ -473,7 +484,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }) it('should handle large number of documents', () => { @@ -482,7 +493,7 @@ describe('DocumentList', () => { const props = { ...defaultProps, documents: manyDocs } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.getByRole('table'))!.toBeInTheDocument() }, 10000) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx index d5e4f480be..b6b02ed829 100644 --- a/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx @@ -103,23 +103,23 @@ describe('DocumentTableRow', () => { describe('Rendering', () => { it('should render without crashing', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('test-document.txt')).toBeInTheDocument() + expect(screen.getByText('test-document.txt'))!.toBeInTheDocument() }) it('should render index number correctly', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('6')).toBeInTheDocument() + expect(screen.getByText('6'))!.toBeInTheDocument() }) it('should render document name with tooltip', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('test-document.txt')).toBeInTheDocument() + expect(screen.getByText('test-document.txt'))!.toBeInTheDocument() }) it('should render checkbox element', () => { const { container } = render(, { wrapper: createWrapper() }) const checkbox = findCheckbox(container) - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() }) }) @@ -127,14 +127,14 @@ describe('DocumentTableRow', () => { it('should show check icon when isSelected is true', () => { const { container } = render(, { wrapper: createWrapper() }) const checkbox = findCheckbox(container) - expect(checkbox).toBeInTheDocument() - expect(screen.getByTestId('check-icon-doc-row-doc-1')).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() + expect(screen.getByTestId('check-icon-doc-row-doc-1'))!.toBeInTheDocument() }) it('should not show check icon when isSelected is false', () => { const { container } = render(, { wrapper: createWrapper() }) const checkbox = findCheckbox(container) - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(screen.queryByTestId('check-icon-doc-row-doc-1')).not.toBeInTheDocument() }) @@ -200,13 +200,13 @@ describe('DocumentTableRow', () => { it('should display word count less than 1000 as is', () => { const doc = createMockDoc({ word_count: 500 }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('500')).toBeInTheDocument() + expect(screen.getByText('500'))!.toBeInTheDocument() }) it('should display word count 1000 or more in k format', () => { const doc = createMockDoc({ word_count: 1500 }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('1.5k')).toBeInTheDocument() + expect(screen.getByText('1.5k'))!.toBeInTheDocument() }) it('should display 0 with empty style when word_count is 0', () => { @@ -219,7 +219,7 @@ describe('DocumentTableRow', () => { it('should handle undefined word_count', () => { const doc = createMockDoc({ word_count: undefined as unknown as number }) const { container } = render(, { wrapper: createWrapper() }) - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) }) @@ -227,13 +227,13 @@ describe('DocumentTableRow', () => { it('should display hit count less than 1000 as is', () => { const doc = createMockDoc({ hit_count: 100 }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('100')).toBeInTheDocument() + expect(screen.getByText('100'))!.toBeInTheDocument() }) it('should display hit count 1000 or more in k format', () => { const doc = createMockDoc({ hit_count: 2500 }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('2.5k')).toBeInTheDocument() + expect(screen.getByText('2.5k'))!.toBeInTheDocument() }) it('should display 0 with empty style when hit_count is 0', () => { @@ -248,12 +248,13 @@ describe('DocumentTableRow', () => { it('should render ChunkingModeLabel with general mode', () => { render(, { wrapper: createWrapper() }) // ChunkingModeLabel should be rendered - expect(screen.getByRole('row')).toBeInTheDocument() + // ChunkingModeLabel should be rendered + expect(screen.getByRole('row'))!.toBeInTheDocument() }) it('should render ChunkingModeLabel with QA mode', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) }) @@ -261,13 +262,13 @@ describe('DocumentTableRow', () => { it('should render SummaryStatus when summary_index_status is present', () => { const doc = createMockDoc({ summary_index_status: 'completed' }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) it('should not render SummaryStatus when summary_index_status is absent', () => { const doc = createMockDoc({ summary_index_status: undefined }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) }) @@ -282,7 +283,7 @@ describe('DocumentTableRow', () => { // Find the rename button by finding the RiEditLine icon's parent const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') if (renameButtons.length > 0) { - fireEvent.click(renameButtons[0]) + fireEvent.click(renameButtons[0]!) expect(onShowRenameModal).toHaveBeenCalledWith(defaultProps.doc) expect(mockPush).not.toHaveBeenCalled() } @@ -292,13 +293,13 @@ describe('DocumentTableRow', () => { describe('Operations', () => { it('should pass selectedIds to Operations component', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) it('should pass onSelectedIdChange to Operations component', () => { const onSelectedIdChange = vi.fn() render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) }) @@ -306,7 +307,7 @@ describe('DocumentTableRow', () => { it('should render with FILE data source type', () => { const doc = createMockDoc({ data_source_type: DataSourceType.FILE }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) it('should render with NOTION data source type', () => { @@ -315,13 +316,13 @@ describe('DocumentTableRow', () => { data_source_info: { notion_page_icon: 'icon.png' }, }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) it('should render with WEB data source type', () => { const doc = createMockDoc({ data_source_type: DataSourceType.WEB }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) }) @@ -329,13 +330,13 @@ describe('DocumentTableRow', () => { it('should handle document with very long name', () => { const doc = createMockDoc({ name: `${'a'.repeat(500)}.txt` }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) it('should handle document with special characters in name', () => { const doc = createMockDoc({ name: '.txt' }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('.txt')).toBeInTheDocument() + expect(screen.getByText('.txt'))!.toBeInTheDocument() }) it('should memoize the component', () => { @@ -343,7 +344,7 @@ describe('DocumentTableRow', () => { const { rerender } = render(, { wrapper }) rerender() - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx index 899c70e216..c8a06ea807 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx @@ -43,7 +43,7 @@ const Options = ({ if (!result.success) { const issues = result.error.issues const firstIssue = issues[0] - const errorMessage = `"${firstIssue.path.join('.')}" ${firstIssue.message}` + const errorMessage = `"${firstIssue!.path.join('.')}" ${firstIssue!.message}` toast.error(errorMessage) return errorMessage } diff --git a/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx b/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx index 33703d56b2..7fde02adcd 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx @@ -33,7 +33,7 @@ const Form = ({ if (!result.success) { const issues = result.error.issues const firstIssue = issues[0] - const errorMessage = `"${firstIssue.path.join('.')}" ${firstIssue.message}` + const errorMessage = `"${firstIssue!.path.join('.')}" ${firstIssue!.message}` toast.error(errorMessage) return errorMessage } diff --git a/web/app/components/datasets/hit-testing/components/query-input/__tests__/index.spec.tsx b/web/app/components/datasets/hit-testing/components/query-input/__tests__/index.spec.tsx index 25b7abe7ea..d9427f5117 100644 --- a/web/app/components/datasets/hit-testing/components/query-input/__tests__/index.spec.tsx +++ b/web/app/components/datasets/hit-testing/components/query-input/__tests__/index.spec.tsx @@ -79,17 +79,17 @@ describe('QueryInput', () => { it('should render title', () => { render() - expect(screen.getByText('datasetHitTesting.input.title')).toBeInTheDocument() + expect(screen.getByText('datasetHitTesting.input.title'))!.toBeInTheDocument() }) it('should render textarea with query text', () => { render() - expect(screen.getByTestId('textarea')).toBeInTheDocument() + expect(screen.getByTestId('textarea'))!.toBeInTheDocument() }) it('should render submit button', () => { render() - expect(screen.getByRole('button', { name: /input\.testing/ })).toBeInTheDocument() + expect(screen.getByRole('button', { name: /input\.testing/ }))!.toBeInTheDocument() }) it('should disable submit button when text is empty', () => { @@ -98,17 +98,17 @@ describe('QueryInput', () => { queries: [{ content: '', content_type: 'text_query', file_info: null }] satisfies Query[], } render() - expect(screen.getByRole('button', { name: /input\.testing/ })).toBeDisabled() + expect(screen.getByRole('button', { name: /input\.testing/ }))!.toBeDisabled() }) it('should render retrieval method for non-external mode', () => { render() - expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.semantic_search.title'))!.toBeInTheDocument() }) it('should render settings button for external mode', () => { render() - expect(screen.getByText('datasetHitTesting.settingTitle')).toBeInTheDocument() + expect(screen.getByText('datasetHitTesting.settingTitle'))!.toBeInTheDocument() }) it('should disable submit button when text exceeds 200 characters', () => { @@ -117,15 +117,15 @@ describe('QueryInput', () => { queries: [{ content: 'a'.repeat(201), content_type: 'text_query', file_info: null }] satisfies Query[], } render() - expect(screen.getByRole('button', { name: /input\.testing/ })).toBeDisabled() + expect(screen.getByRole('button', { name: /input\.testing/ }))!.toBeDisabled() }) it('should show loading state on submit button when loading', () => { render() const submitButton = screen.getByRole('button', { name: /input\.testing/ }) - expect(submitButton).toBeDisabled() - expect(submitButton).toHaveAttribute('aria-busy', 'true') - expect(submitButton.querySelector('.animate-spin')).toBeInTheDocument() + expect(submitButton)!.toBeDisabled() + expect(submitButton)!.toHaveAttribute('aria-busy', 'true') + expect(submitButton.querySelector('.animate-spin'))!.toBeInTheDocument() }) // Cover line 83: images useMemo with image_query data @@ -141,6 +141,37 @@ describe('QueryInput', () => { ] render() + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image + // Submit should be enabled since we have text + uploaded image // Submit should be enabled since we have text + uploaded image expect(screen.getByRole('button', { name: /input\.testing/ })).not.toBeDisabled() }) @@ -153,7 +184,7 @@ describe('QueryInput', () => { // Click settings button to open modal fireEvent.click(screen.getByRole('button', { name: /settingTitle/ })) - expect(screen.getByTestId('external-retrieval-modal')).toBeInTheDocument() + expect(screen.getByTestId('external-retrieval-modal'))!.toBeInTheDocument() // Close modal fireEvent.click(screen.getByTestId('modal-close')) @@ -165,7 +196,7 @@ describe('QueryInput', () => { // Open modal fireEvent.click(screen.getByRole('button', { name: /settingTitle/ })) - expect(screen.getByTestId('external-retrieval-modal')).toBeInTheDocument() + expect(screen.getByTestId('external-retrieval-modal'))!.toBeInTheDocument() // Save settings fireEvent.click(screen.getByTestId('modal-save')) @@ -274,7 +305,7 @@ describe('QueryInput', () => { ]), ) // Should not contain image_query - const calledWith = defaultProps.setQueries.mock.calls[0][0] as Query[] + const calledWith = defaultProps.setQueries.mock.calls[0]![0] as Query[] expect(calledWith.filter(q => q.content_type === 'image_query')).toHaveLength(0) }) }) @@ -412,7 +443,7 @@ describe('QueryInput', () => { it('should show keyword_search when isEconomy is true', () => { render() - expect(screen.getByText('dataset.retrieval.keyword_search.title')).toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.keyword_search.title'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/metadata/edit-metadata-batch/__tests__/modal.spec.tsx b/web/app/components/datasets/metadata/edit-metadata-batch/__tests__/modal.spec.tsx index d9b88e20bb..40c925222c 100644 --- a/web/app/components/datasets/metadata/edit-metadata-batch/__tests__/modal.spec.tsx +++ b/web/app/components/datasets/metadata/edit-metadata-batch/__tests__/modal.spec.tsx @@ -120,14 +120,14 @@ describe('EditMetadataBatchModal', () => { it('should render without crashing', async () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) }) it('should render document count', async () => { render() await waitFor(() => { - expect(screen.getByText(/5/)).toBeInTheDocument() + expect(screen.getByText(/5/))!.toBeInTheDocument() }) }) @@ -142,8 +142,8 @@ describe('EditMetadataBatchModal', () => { it('should render field names for existing items', async () => { render() await waitFor(() => { - expect(screen.getByText('field_one')).toBeInTheDocument() - expect(screen.getByText('field_two')).toBeInTheDocument() + expect(screen.getByText('field_one'))!.toBeInTheDocument() + expect(screen.getByText('field_two'))!.toBeInTheDocument() }) }) @@ -158,7 +158,7 @@ describe('EditMetadataBatchModal', () => { it('should render select metadata modal', async () => { render() await waitFor(() => { - expect(screen.getByTestId('select-modal')).toBeInTheDocument() + expect(screen.getByTestId('select-modal'))!.toBeInTheDocument() }) }) }) @@ -169,7 +169,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) const cancelButton = screen.getByText(/cancel/i) @@ -183,7 +183,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) // Find the primary save button (not the one in SelectMetadataModal) @@ -196,17 +196,17 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) const checkboxContainer = document.querySelector('[data-testid*="checkbox"]') - expect(checkboxContainer).toBeInTheDocument() + expect(checkboxContainer)!.toBeInTheDocument() if (checkboxContainer) { fireEvent.click(checkboxContainer) await waitFor(() => { const checkIcon = screen.getByTestId('check-icon-apply-to-all') - expect(checkIcon).toBeInTheDocument() + expect(checkIcon)!.toBeInTheDocument() }) } }) @@ -216,7 +216,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) }) }) @@ -226,7 +226,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('change-1')) @@ -239,7 +239,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('remove-1')) @@ -252,7 +252,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) // First change the item @@ -269,14 +269,14 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('select-metadata')) // Should now have add-row for the new item await waitFor(() => { - expect(screen.getByTestId('add-row')).toBeInTheDocument() + expect(screen.getByTestId('add-row'))!.toBeInTheDocument() }) }) @@ -284,14 +284,14 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) // First add an item fireEvent.click(screen.getByTestId('select-metadata')) await waitFor(() => { - expect(screen.getByTestId('add-row')).toBeInTheDocument() + expect(screen.getByTestId('add-row'))!.toBeInTheDocument() }) // Then remove it @@ -306,20 +306,20 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) // First add an item fireEvent.click(screen.getByTestId('select-metadata')) await waitFor(() => { - expect(screen.getByTestId('add-row')).toBeInTheDocument() + expect(screen.getByTestId('add-row'))!.toBeInTheDocument() }) // Then change it fireEvent.click(screen.getByTestId('add-change-new-1')) - expect(screen.getByTestId('add-row')).toBeInTheDocument() + expect(screen.getByTestId('add-row'))!.toBeInTheDocument() }) it('should call doAddMetaData when saving new metadata with valid name', async () => { @@ -328,7 +328,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('save-metadata')) @@ -344,7 +344,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('save-metadata')) @@ -368,7 +368,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('save-metadata')) @@ -388,7 +388,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('manage-metadata')) @@ -401,14 +401,14 @@ describe('EditMetadataBatchModal', () => { it('should pass correct datasetId', async () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) }) it('should display correct document number', async () => { render() await waitFor(() => { - expect(screen.getByText(/10/)).toBeInTheDocument() + expect(screen.getByText(/10/))!.toBeInTheDocument() }) }) @@ -427,7 +427,7 @@ describe('EditMetadataBatchModal', () => { ] render() await waitFor(() => { - expect(screen.getByTestId('edit-row')).toBeInTheDocument() + expect(screen.getByTestId('edit-row'))!.toBeInTheDocument() }) }) @@ -436,7 +436,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) // Find the primary save button @@ -453,7 +453,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) @@ -470,7 +470,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) const checkboxContainer = document.querySelector('[data-testid*="checkbox"]') @@ -493,7 +493,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) // Remove an item @@ -503,7 +503,7 @@ describe('EditMetadataBatchModal', () => { expect(onSave).toHaveBeenCalled() // The first argument should not contain the deleted item (id '1') - const savedList = onSave.mock.calls[0][0] as MetadataItemInBatchEdit[] + const savedList = onSave.mock.calls[0]![0] as MetadataItemInBatchEdit[] const hasDeletedItem = savedList.some(item => item.id === '1') expect(hasDeletedItem).toBe(false) }) @@ -512,13 +512,13 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByRole('dialog'))!.toBeInTheDocument() }) // Add first item fireEvent.click(screen.getByTestId('select-metadata')) await waitFor(() => { - expect(screen.getByTestId('add-row')).toBeInTheDocument() + expect(screen.getByTestId('add-row'))!.toBeInTheDocument() }) // Remove it @@ -531,7 +531,7 @@ describe('EditMetadataBatchModal', () => { // Add again fireEvent.click(screen.getByTestId('select-metadata')) await waitFor(() => { - expect(screen.getByTestId('add-row')).toBeInTheDocument() + expect(screen.getByTestId('add-row'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx b/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx index ddd624a076..71f23324f6 100644 --- a/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx +++ b/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx @@ -99,7 +99,7 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should render metadata fields when hasData is true', () => { @@ -110,8 +110,8 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(screen.getByText('field_one')).toBeInTheDocument() - expect(screen.getByText('field_two')).toBeInTheDocument() + expect(screen.getByText('field_one'))!.toBeInTheDocument() + expect(screen.getByText('field_two'))!.toBeInTheDocument() }) it('should render no-data state when hasData is false and not in edit mode', () => { @@ -147,8 +147,8 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText(/save/i)).toBeInTheDocument() - expect(screen.getByText(/cancel/i)).toBeInTheDocument() + expect(screen.getByText(/save/i))!.toBeInTheDocument() + expect(screen.getByText(/cancel/i))!.toBeInTheDocument() }) it('should render built-in section when builtInEnabled is true', () => { @@ -166,7 +166,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('created_at')).toBeInTheDocument() + expect(screen.getByText('created_at'))!.toBeInTheDocument() }) it('should render divider when builtInEnabled is true', () => { @@ -185,7 +185,7 @@ describe('MetadataDocument', () => { ) const divider = container.querySelector('.bg-linear-to-r') - expect(divider).toBeInTheDocument() + expect(divider)!.toBeInTheDocument() }) it('should render origin info section', () => { @@ -202,7 +202,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('source')).toBeInTheDocument() + expect(screen.getByText('source'))!.toBeInTheDocument() }) it('should render technical parameters section', () => { @@ -219,7 +219,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('word_count')).toBeInTheDocument() + expect(screen.getByText('word_count'))!.toBeInTheDocument() }) it('should render all sections together', () => { @@ -239,10 +239,10 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('field_one')).toBeInTheDocument() - expect(screen.getByText('created_at')).toBeInTheDocument() - expect(screen.getByText('source')).toBeInTheDocument() - expect(screen.getByText('word_count')).toBeInTheDocument() + expect(screen.getByText('field_one'))!.toBeInTheDocument() + expect(screen.getByText('created_at'))!.toBeInTheDocument() + expect(screen.getByText('source'))!.toBeInTheDocument() + expect(screen.getByText('word_count'))!.toBeInTheDocument() }) }) @@ -255,7 +255,7 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(screen.getByText(/edit/i)).toBeInTheDocument() + expect(screen.getByText(/edit/i))!.toBeInTheDocument() }) it('should call startToEdit when edit button is clicked', () => { @@ -362,8 +362,9 @@ describe('MetadataDocument', () => { ) // Should show save/cancel buttons - expect(screen.getByText(/save/i)).toBeInTheDocument() - expect(screen.getByText(/cancel/i)).toBeInTheDocument() + // Should show save/cancel buttons + expect(screen.getByText(/save/i))!.toBeInTheDocument() + expect(screen.getByText(/cancel/i))!.toBeInTheDocument() }) }) @@ -386,7 +387,7 @@ describe('MetadataDocument', () => { const inputs = container.querySelectorAll('input') if (inputs.length > 0) { - fireEvent.change(inputs[0], { target: { value: 'new value' } }) + fireEvent.change(inputs[0]!, { target: { value: 'new value' } }) await waitFor(() => { expect(setTempList).toHaveBeenCalled() @@ -454,7 +455,7 @@ describe('MetadataDocument', () => { const inputs = container.querySelectorAll('input') if (inputs.length > 0) { - fireEvent.change(inputs[0], { target: { value: 'updated' } }) + fireEvent.change(inputs[0]!, { target: { value: 'updated' } }) await waitFor(() => { expect(setTempList).toHaveBeenCalled() }) @@ -483,7 +484,7 @@ describe('MetadataDocument', () => { expect(deleteContainers.length).toBeGreaterThan(0) if (deleteContainers.length > 0) { - const deleteIcon = deleteContainers[0].querySelector('svg') + const deleteIcon = deleteContainers[0]!.querySelector('svg') if (deleteIcon) fireEvent.click(deleteIcon) @@ -504,7 +505,7 @@ describe('MetadataDocument', () => { className="custom-class" />, ) - expect(container.firstChild).toHaveClass('custom-class') + expect(container.firstChild)!.toHaveClass('custom-class') }) it('should use tempList when in edit mode', () => { @@ -524,7 +525,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('temp_field')).toBeInTheDocument() + expect(screen.getByText('temp_field'))!.toBeInTheDocument() }) it('should use list when not in edit mode', () => { @@ -536,8 +537,8 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('field_one')).toBeInTheDocument() - expect(screen.getByText('field_two')).toBeInTheDocument() + expect(screen.getByText('field_one'))!.toBeInTheDocument() + expect(screen.getByText('field_two'))!.toBeInTheDocument() }) it('should pass datasetId to child components', () => { @@ -549,7 +550,8 @@ describe('MetadataDocument', () => { />, ) // Component should render without errors - expect(screen.getByText('field_one')).toBeInTheDocument() + // Component should render without errors + expect(screen.getByText('field_one'))!.toBeInTheDocument() }) }) @@ -588,6 +590,37 @@ describe('MetadataDocument', () => { />, ) + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered + // NoData component should not be rendered // NoData component should not be rendered expect(screen.queryByText(/start/i)).not.toBeInTheDocument() }) @@ -607,6 +640,37 @@ describe('MetadataDocument', () => { />, ) + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined + // headerRight should be null/undefined // headerRight should be null/undefined expect(screen.queryByText(/^edit$/i)).not.toBeInTheDocument() }) @@ -628,7 +692,7 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should render correctly with minimal props', () => { @@ -639,7 +703,7 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should handle switching between view and edit mode', () => { @@ -651,7 +715,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText(/edit/i)).toBeInTheDocument() + expect(screen.getByText(/edit/i))!.toBeInTheDocument() unmount() @@ -668,8 +732,8 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText(/save/i)).toBeInTheDocument() - expect(screen.getByText(/cancel/i)).toBeInTheDocument() + expect(screen.getByText(/save/i))!.toBeInTheDocument() + expect(screen.getByText(/cancel/i))!.toBeInTheDocument() }) it('should handle multiple items in all sections', () => { @@ -702,11 +766,11 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('user_field_1')).toBeInTheDocument() - expect(screen.getByText('user_field_2')).toBeInTheDocument() - expect(screen.getByText('created_at')).toBeInTheDocument() - expect(screen.getByText('source')).toBeInTheDocument() - expect(screen.getByText('word_count')).toBeInTheDocument() + expect(screen.getByText('user_field_1'))!.toBeInTheDocument() + expect(screen.getByText('user_field_2'))!.toBeInTheDocument() + expect(screen.getByText('created_at'))!.toBeInTheDocument() + expect(screen.getByText('source'))!.toBeInTheDocument() + expect(screen.getByText('word_count'))!.toBeInTheDocument() }) it('should handle null values in metadata', () => { @@ -725,7 +789,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('null_field')).toBeInTheDocument() + expect(screen.getByText('null_field'))!.toBeInTheDocument() }) it('should handle undefined values in metadata', () => { @@ -744,7 +808,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('undefined_field')).toBeInTheDocument() + expect(screen.getByText('undefined_field'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx b/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx index 7441274155..5e81611fc4 100644 --- a/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx +++ b/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx @@ -19,12 +19,12 @@ describe('IndexMethod', () => { describe('Rendering', () => { it('should render without crashing', () => { render() - expect(screen.getByText(/stepTwo\.qualified/)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.qualified/))!.toBeInTheDocument() }) it('should render High Quality option', () => { render() - expect(screen.getByText(/stepTwo\.qualified/)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.qualified/))!.toBeInTheDocument() }) it('should render Economy option', () => { @@ -34,17 +34,17 @@ describe('IndexMethod', () => { it('should render High Quality description', () => { render() - expect(screen.getByText(/form\.indexMethodHighQualityTip/)).toBeInTheDocument() + expect(screen.getByText(/form\.indexMethodHighQualityTip/))!.toBeInTheDocument() }) it('should render Economy description', () => { render() - expect(screen.getByText(/form\.indexMethodEconomyTip/)).toBeInTheDocument() + expect(screen.getByText(/form\.indexMethodEconomyTip/))!.toBeInTheDocument() }) it('should render recommended badge on High Quality', () => { render() - expect(screen.getByText(/stepTwo\.recommend/)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.recommend/))!.toBeInTheDocument() }) }) @@ -82,7 +82,7 @@ describe('IndexMethod', () => { // Find and click Economy option - use getAllByText and get the first one (title) const economyTitles = screen.getAllByText(/form\.indexMethodEconomy/) const economyTitle = economyTitles[0] - const card = economyTitle.closest('div')?.parentElement?.parentElement?.parentElement + const card = economyTitle!.closest('div')?.parentElement?.parentElement?.parentElement fireEvent.click(card!) expect(handleChange).toHaveBeenCalledWith(IndexingType.ECONOMICAL) @@ -114,7 +114,7 @@ describe('IndexMethod', () => { // Try to click Economy option - use getAllByText and get the first one (title) const economyTitles = screen.getAllByText(/form\.indexMethodEconomy/) const economyTitle = economyTitles[0] - const card = economyTitle.closest('div')?.parentElement?.parentElement?.parentElement + const card = economyTitle!.closest('div')?.parentElement?.parentElement?.parentElement fireEvent.click(card!) // Should not call onChange because Economy is disabled when current is QUALIFIED @@ -125,13 +125,13 @@ describe('IndexMethod', () => { describe('KeywordNumber', () => { it('should render KeywordNumber component inside Economy option', () => { render() - expect(getKeywordSlider()).toBeInTheDocument() + expect(getKeywordSlider())!.toBeInTheDocument() }) it('should pass keywordNumber to KeywordNumber component', () => { render() const input = screen.getByRole('textbox') - expect(input).toHaveValue('25') + expect(input)!.toHaveValue('25') }) it('should call onKeywordNumberChange when KeywordNumber changes', () => { @@ -160,13 +160,13 @@ describe('IndexMethod', () => { it('should show orange effect color for High Quality option', () => { const { container } = render() const orangeEffect = container.querySelector('.bg-util-colors-orange-orange-500') - expect(orangeEffect).toBeInTheDocument() + expect(orangeEffect)!.toBeInTheDocument() }) it('should show indigo effect color for Economy option', () => { const { container } = render() const indigoEffect = container.querySelector('.bg-util-colors-indigo-indigo-600') - expect(indigoEffect).toBeInTheDocument() + expect(indigoEffect)!.toBeInTheDocument() }) }) @@ -188,19 +188,20 @@ describe('IndexMethod', () => { it('should handle undefined currentValue', () => { render() // Should render without error - expect(screen.getByText(/stepTwo\.qualified/)).toBeInTheDocument() + // Should render without error + expect(screen.getByText(/stepTwo\.qualified/))!.toBeInTheDocument() }) it('should handle minimum keywordNumber', () => { render() const input = screen.getByRole('textbox') - expect(input).toHaveValue('0') + expect(input)!.toHaveValue('0') }) it('should handle max keywordNumber', () => { render() const input = screen.getByRole('textbox') - expect(input).toHaveValue('50') + expect(input)!.toHaveValue('50') }) }) }) diff --git a/web/app/components/datasets/settings/index-method/keyword-number.tsx b/web/app/components/datasets/settings/index-method/keyword-number.tsx index feb63c1d65..03992fb027 100644 --- a/web/app/components/datasets/settings/index-method/keyword-number.tsx +++ b/web/app/components/datasets/settings/index-method/keyword-number.tsx @@ -33,7 +33,7 @@ const KeyWordNumber = ({ return (
-
+
{t('form.numberOfKeywords', { ns: 'datasetSettings' })}
- - + + - - + + diff --git a/web/app/components/datasets/settings/permission-selector/index.tsx b/web/app/components/datasets/settings/permission-selector/index.tsx index a7182d8f79..8c31799add 100644 --- a/web/app/components/datasets/settings/permission-selector/index.tsx +++ b/web/app/components/datasets/settings/permission-selector/index.tsx @@ -133,8 +133,8 @@ const PermissionSelector = ({ { selectedMembers.length === 1 && ( ) @@ -143,14 +143,14 @@ const PermissionSelector = ({ selectedMembers.length >= 2 && ( <> diff --git a/web/app/components/evaluation/components/conditions-section/add-condition-select.tsx b/web/app/components/evaluation/components/conditions-section/add-condition-select.tsx index cbe44c05b2..a37a8d7bad 100644 --- a/web/app/components/evaluation/components/conditions-section/add-condition-select.tsx +++ b/web/app/components/evaluation/components/conditions-section/add-condition-select.tsx @@ -8,8 +8,8 @@ import { Select, SelectContent, SelectGroup, - SelectGroupLabel, SelectItem, + SelectLabel, SelectTrigger, } from '@/app/components/base/ui/select' import { useEvaluationStore } from '../../store' @@ -46,7 +46,7 @@ const AddConditionSelect = ({ {metricOptionGroups.map(group => ( - {group.label} + {group.label} {group.options.map(option => (
- {option.itemLabel} - + {option.itemLabel} + {t(getConditionMetricValueTypeTranslationKey(option.valueType))}
diff --git a/web/app/components/evaluation/components/conditions-section/condition-group.tsx b/web/app/components/evaluation/components/conditions-section/condition-group.tsx index c37fb615dc..8649ec7f1d 100644 --- a/web/app/components/evaluation/components/conditions-section/condition-group.tsx +++ b/web/app/components/evaluation/components/conditions-section/condition-group.tsx @@ -15,8 +15,8 @@ import { Select, SelectContent, SelectGroup, - SelectGroupLabel, SelectItem, + SelectLabel, SelectTrigger, SelectValue, } from '@/app/components/base/ui/select' @@ -71,15 +71,15 @@ const ConditionMetricLabel = ({ placeholder, }: ConditionMetricLabelProps) => { if (!metric) - return {placeholder} + return {placeholder} return (
- {metric.itemLabel} + {metric.itemLabel}
- {metric.groupLabel} + {metric.groupLabel}
) } @@ -110,13 +110,13 @@ const ConditionMetricSelect = ({ {groupedMetricOptions.map(group => ( - {group.label} + {group.label} {group.options.map(option => (
{option.itemLabel} - + {t(getConditionMetricValueTypeTranslationKey(option.valueType))}
@@ -139,7 +139,7 @@ const ConditionOperatorSelect = ({ return ( { return ( <>
-
+
- {currentWorkspace?.name[0]?.toLocaleUpperCase()} + {currentWorkspace?.name[0]?.toLocaleUpperCase()}
-
+
{currentWorkspace?.name} {isCurrentWorkspaceOwner && ( @@ -82,7 +82,7 @@ const MembersPage = () => { )}
-
+
{enableBilling && isNotUnlimitedMemberPlan ? (
@@ -116,9 +116,9 @@ const MembersPage = () => {
-
{t('members.name', { ns: 'common' })}
-
{t('members.lastActive', { ns: 'common' })}
-
{t('members.role', { ns: 'common' })}
+
{t('members.name', { ns: 'common' })}
+
{t('members.lastActive', { ns: 'common' })}
+
{t('members.role', { ns: 'common' })}
{ @@ -127,27 +127,27 @@ const MembersPage = () => {
-
+
{account.name} - {account.status === 'pending' && {t('members.pending', { ns: 'common' })}} - {userProfile.email === account.email && {t('members.you', { ns: 'common' })}} + {account.status === 'pending' && {t('members.pending', { ns: 'common' })}} + {userProfile.email === account.email && {t('members.you', { ns: 'common' })}}
-
{account.email}
+
{account.email}
-
{formatTimeFromNow(Number((account.last_active_at || account.created_at)) * 1000)}
+
{formatTimeFromNow(Number((account.last_active_at || account.created_at)) * 1000)}
{isCurrentWorkspaceOwner && account.role === 'owner' && isAllowTransferWorkspace && ( setShowTransferOwnershipModal(true)}> )} {isCurrentWorkspaceOwner && account.role === 'owner' && !isAllowTransferWorkspace && ( -
{RoleMap[account.role] || RoleMap.normal}
+
{RoleMap[account.role] || RoleMap.normal}
)} {isCurrentWorkspaceOwner && account.role !== 'owner' && ( )} {!isCurrentWorkspaceOwner && ( -
{RoleMap[account.role] || RoleMap.normal}
+
{RoleMap[account.role] || RoleMap.normal}
)}
diff --git a/web/app/components/header/account-setting/members-page/operation/index.tsx b/web/app/components/header/account-setting/members-page/operation/index.tsx index 5fb3be0195..dab0c84e5b 100644 --- a/web/app/components/header/account-setting/members-page/operation/index.tsx +++ b/web/app/components/header/account-setting/members-page/operation/index.tsx @@ -1,10 +1,15 @@ 'use client' import type { Member } from '@/models/common' -import { CheckIcon, ChevronDownIcon } from '@heroicons/react/24/outline' import { cn } from '@langgenius/dify-ui/cn' import { memo, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' import { toast } from '@/app/components/base/ui/toast' import { useProviderContext } from '@/context/provider-context' import { deleteMemberOrCancelInvitation, updateMemberRole } from '@/service/common' @@ -74,40 +79,50 @@ const Operation = ({ member, operatorRole, onOperate }: IOperationProps) => { } } return ( - - setOpen(prev => !prev)}> -
- {RoleMap[member.role] || RoleMap.normal} - -
-
- -
-
- {roleList.map(role => ( -
handleUpdateMemberRole(role)}> - {role === member.role - ? - :
} -
-
{t(roleI18nKeyMap[role].label, { ns: 'common' })}
-
{t(roleI18nKeyMap[role].tip, { ns: 'common' })}
-
-
- ))} -
-
-
-
+ + } + > + {RoleMap[member.role] || RoleMap.normal} + + + +
+ {roleList.map(role => ( + handleUpdateMemberRole(role)} + > + {role === member.role + ? + : }
-
{t('members.removeFromTeam', { ns: 'common' })}
-
{t('members.removeFromTeamTip', { ns: 'common' })}
+
{t(roleI18nKeyMap[role].label, { ns: 'common' })}
+
{t(roleI18nKeyMap[role].tip, { ns: 'common' })}
-
-
+ + ))}
- - + +
+ + +
+
{t('members.removeFromTeam', { ns: 'common' })}
+
{t('members.removeFromTeamTip', { ns: 'common' })}
+
+
+
+ + ) } export default memo(Operation) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx index 0f507a24c6..be3e9edb93 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx @@ -58,9 +58,9 @@ const CredentialSelector = ({ selectedCredential && (
{ - !selectedCredential.addNewCredential && + !selectedCredential.addNewCredential && } -
{selectedCredential.credential_name}
+
{selectedCredential.credential_name}
{ selectedCredential.from_enterprise && ( Enterprise @@ -71,7 +71,7 @@ const CredentialSelector = ({ } { !selectedCredential && ( -
{t('modelProvider.auth.selectModelCredential', { ns: 'common' })}
+
{t('modelProvider.auth.selectModelCredential', { ns: 'common' })}
) } @@ -98,7 +98,7 @@ const CredentialSelector = ({ { !notAllowAddNewCredential && (
diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/status-indicators.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/status-indicators.tsx index 3cb37fd4e5..cca5846390 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/status-indicators.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/status-indicators.tsx @@ -20,12 +20,12 @@ const StatusIndicators = ({ needsConfiguration, modelProvider, inModelList, disa
e.stopPropagation()}>
{title}
{description && ( -
+
{description}
)} {linkText && linkHref && ( -
+
{ diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx index 3556777395..652630be67 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx @@ -164,7 +164,7 @@ const ModelLoadBalancingModal = ({ provider, configurateMethod, currentCustomCon const prevIndex = newConfigs.findIndex(item => item.credential_id === modelCredential.credential_id && item.name !== '__inherit__') const newIndex = available_credentials.findIndex(c => c.credential_id === modelCredential.credential_id) if (newIndex > -1 && prevIndex > -1) - newConfigs[prevIndex].name = available_credentials[newIndex].credential_name || '' + newConfigs[prevIndex]!.name = available_credentials[newIndex]!.credential_name || '' return { ...prev, configs: newConfigs, diff --git a/web/app/components/header/index.tsx b/web/app/components/header/index.tsx index 560d5f1eaa..d92fd032fd 100644 --- a/web/app/components/header/index.tsx +++ b/web/app/components/header/index.tsx @@ -45,7 +45,7 @@ const Header = () => { const renderLogo = () => (

- + {isBrandingEnabled && systemFeatures.branding.application_title ? systemFeatures.branding.application_title : 'Dify'} {systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo ? ( @@ -91,7 +91,7 @@ const Header = () => { return (
-
+
{renderLogo()}
/
@@ -105,7 +105,7 @@ const Header = () => { {(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator) && } {!isCurrentWorkspaceDatasetOperator && }
-
+
diff --git a/web/app/components/plugins/marketplace/description/index.tsx b/web/app/components/plugins/marketplace/description/index.tsx index c3c1ab5fb1..c91110b8cc 100644 --- a/web/app/components/plugins/marketplace/description/index.tsx +++ b/web/app/components/plugins/marketplace/description/index.tsx @@ -9,10 +9,10 @@ const Description = () => { return ( <> -

+

{t('marketplace.empower')}

-

+

{ isZhHans && ( <> @@ -29,31 +29,31 @@ const Description = () => { ) } - + {t('category.models')} , - + {t('category.tools')} , - + {t('category.datasources')} , - + {t('category.triggers')} , - + {t('category.agents')} , - + {t('category.extensions')} {t('marketplace.and')} - + {t('category.bundles')} { diff --git a/web/app/components/plugins/marketplace/search-box/__tests__/index.spec.tsx b/web/app/components/plugins/marketplace/search-box/__tests__/index.spec.tsx index 31cb5d8445..8609ba5539 100644 --- a/web/app/components/plugins/marketplace/search-box/__tests__/index.spec.tsx +++ b/web/app/components/plugins/marketplace/search-box/__tests__/index.spec.tsx @@ -124,7 +124,7 @@ describe('SearchBox', () => { it('should render without crashing', () => { render() - expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('textbox'))!.toBeInTheDocument() }) it('should render with marketplace mode styling', () => { @@ -133,7 +133,8 @@ describe('SearchBox', () => { ) // In marketplace mode, TagsFilter comes before input - expect(container.querySelector('.rounded-xl')).toBeInTheDocument() + // In marketplace mode, TagsFilter comes before input + expect(container.querySelector('.rounded-xl'))!.toBeInTheDocument() }) it('should render with non-marketplace mode styling', () => { @@ -142,25 +143,26 @@ describe('SearchBox', () => { ) // In non-marketplace mode, search icon appears first - expect(container.querySelector('.rounded-lg')).toBeInTheDocument() + // In non-marketplace mode, search icon appears first + expect(container.querySelector('.rounded-lg'))!.toBeInTheDocument() }) it('should render placeholder correctly', () => { render() - expect(screen.getByPlaceholderText('Search here...')).toBeInTheDocument() + expect(screen.getByPlaceholderText('Search here...'))!.toBeInTheDocument() }) it('should render search input with current value', () => { render() - expect(screen.getByDisplayValue('test query')).toBeInTheDocument() + expect(screen.getByDisplayValue('test query'))!.toBeInTheDocument() }) it('should render TagsFilter component', () => { render() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-elem'))!.toBeInTheDocument() }) }) @@ -175,8 +177,9 @@ describe('SearchBox', () => { const input = screen.getByRole('textbox') // Both should be rendered - expect(portalElem).toBeInTheDocument() - expect(input).toBeInTheDocument() + // Both should be rendered + expect(portalElem)!.toBeInTheDocument() + expect(input)!.toBeInTheDocument() }) it('should render clear button when search has value in marketplace mode', () => { @@ -207,7 +210,8 @@ describe('SearchBox', () => { ) // Search icon should be present - expect(container.querySelector('.text-components-input-text-placeholder')).toBeInTheDocument() + // Search icon should be present + expect(container.querySelector('.text-components-input-text-placeholder'))!.toBeInTheDocument() }) it('should render clear button when search has value', () => { @@ -223,8 +227,8 @@ describe('SearchBox', () => { const portalElem = screen.getByTestId('portal-elem') const input = screen.getByRole('textbox') - expect(portalElem).toBeInTheDocument() - expect(input).toBeInTheDocument() + expect(portalElem)!.toBeInTheDocument() + expect(input)!.toBeInTheDocument() }) it('should set autoFocus when prop is true', () => { @@ -232,7 +236,8 @@ describe('SearchBox', () => { const input = screen.getByRole('textbox') // autoFocus is a boolean attribute that React handles specially - expect(input).toBeInTheDocument() + // autoFocus is a boolean attribute that React handles specially + expect(input)!.toBeInTheDocument() }) }) @@ -264,7 +269,7 @@ describe('SearchBox', () => { const buttons = screen.getAllByRole('button') // Find the clear button (the one in the search area) const clearButton = buttons[buttons.length - 1] - fireEvent.click(clearButton) + fireEvent.click(clearButton!) expect(onSearchChange).toHaveBeenCalledWith('') }) @@ -282,7 +287,7 @@ describe('SearchBox', () => { const buttons = screen.getAllByRole('button') // First button should be the clear button in non-marketplace mode - fireEvent.click(buttons[0]) + fireEvent.click(buttons[0]!) expect(onSearchChange).toHaveBeenCalledWith('') }) @@ -356,7 +361,7 @@ describe('SearchBox', () => { , ) - expect(container.querySelector('.custom-wrapper-class')).toBeInTheDocument() + expect(container.querySelector('.custom-wrapper-class'))!.toBeInTheDocument() }) it('should apply inputClassName correctly', () => { @@ -364,19 +369,19 @@ describe('SearchBox', () => { , ) - expect(container.querySelector('.custom-input-class')).toBeInTheDocument() + expect(container.querySelector('.custom-input-class'))!.toBeInTheDocument() }) it('should handle empty placeholder', () => { render() - expect(screen.getByRole('textbox')).toHaveAttribute('placeholder', '') + expect(screen.getByRole('textbox'))!.toHaveAttribute('placeholder', '') }) it('should use default placeholder when not provided', () => { render() - expect(screen.getByRole('textbox')).toHaveAttribute('placeholder', '') + expect(screen.getByRole('textbox'))!.toHaveAttribute('placeholder', '') }) }) @@ -387,14 +392,14 @@ describe('SearchBox', () => { it('should handle empty search value', () => { render() - expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('textbox')).toHaveValue('') + expect(screen.getByRole('textbox'))!.toBeInTheDocument() + expect(screen.getByRole('textbox'))!.toHaveValue('') }) it('should handle empty tags array', () => { render() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-elem'))!.toBeInTheDocument() }) it('should handle special characters in search', () => { @@ -411,7 +416,7 @@ describe('SearchBox', () => { const longString = 'a'.repeat(1000) render() - expect(screen.getByDisplayValue(longString)).toBeInTheDocument() + expect(screen.getByDisplayValue(longString))!.toBeInTheDocument() }) it('should handle whitespace-only search', () => { @@ -439,20 +444,21 @@ describe('SearchBoxWrapper', () => { it('should render without crashing', () => { render() - expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('textbox'))!.toBeInTheDocument() }) it('should render in marketplace mode', () => { const { container } = render() - expect(container.querySelector('.rounded-xl')).toBeInTheDocument() + expect(container.querySelector('.rounded-xl'))!.toBeInTheDocument() }) it('should apply correct wrapper classes', () => { const { container } = render() // Check for z-11 class from wrapper - expect(container.querySelector('.z-11')).toBeInTheDocument() + // Check for z-11 class from wrapper + expect(container.querySelector('.z-11'))!.toBeInTheDocument() }) }) @@ -471,7 +477,7 @@ describe('SearchBoxWrapper', () => { it('should use translation for placeholder', () => { render() - expect(screen.getByPlaceholderText('Search plugins')).toBeInTheDocument() + expect(screen.getByPlaceholderText('Search plugins'))!.toBeInTheDocument() }) }) }) @@ -496,13 +502,13 @@ describe('MarketplaceTrigger', () => { it('should render without crashing', () => { render() - expect(screen.getByText('All Tags')).toBeInTheDocument() + expect(screen.getByText('All Tags'))!.toBeInTheDocument() }) it('should show "All Tags" when no tags selected', () => { render() - expect(screen.getByText('All Tags')).toBeInTheDocument() + expect(screen.getByText('All Tags'))!.toBeInTheDocument() }) it('should show arrow down icon when no tags selected', () => { @@ -511,7 +517,8 @@ describe('MarketplaceTrigger', () => { ) // Arrow down icon should be present - expect(container.querySelector('.size-4')).toBeInTheDocument() + // Arrow down icon should be present + expect(container.querySelector('.size-4'))!.toBeInTheDocument() }) }) @@ -525,7 +532,7 @@ describe('MarketplaceTrigger', () => { />, ) - expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() }) it('should show multiple tag labels separated by comma', () => { @@ -537,7 +544,7 @@ describe('MarketplaceTrigger', () => { />, ) - expect(screen.getByText('Agent,RAG')).toBeInTheDocument() + expect(screen.getByText('Agent,RAG'))!.toBeInTheDocument() }) it('should show +N indicator when more than 2 tags selected', () => { @@ -549,7 +556,7 @@ describe('MarketplaceTrigger', () => { />, ) - expect(screen.getByText('+2')).toBeInTheDocument() + expect(screen.getByText('+2'))!.toBeInTheDocument() }) it('should only show first 2 tags in label', () => { @@ -561,7 +568,7 @@ describe('MarketplaceTrigger', () => { />, ) - expect(screen.getByText('Agent,RAG')).toBeInTheDocument() + expect(screen.getByText('Agent,RAG'))!.toBeInTheDocument() expect(screen.queryByText('Search')).not.toBeInTheDocument() }) }) @@ -577,7 +584,8 @@ describe('MarketplaceTrigger', () => { ) // RiCloseCircleFill icon should be present - expect(container.querySelector('.text-text-quaternary')).toBeInTheDocument() + // RiCloseCircleFill icon should be present + expect(container.querySelector('.text-text-quaternary'))!.toBeInTheDocument() }) it('should not show clear button when no tags selected', () => { @@ -585,6 +593,37 @@ describe('MarketplaceTrigger', () => { , ) + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present + // Clear button should not be present // Clear button should not be present expect(container.querySelector('.text-text-quaternary')).not.toBeInTheDocument() }) @@ -614,7 +653,7 @@ describe('MarketplaceTrigger', () => { , ) - expect(container.querySelector('.bg-state-base-hover')).toBeInTheDocument() + expect(container.querySelector('.bg-state-base-hover'))!.toBeInTheDocument() }) it('should apply border styling when tags are selected', () => { @@ -626,7 +665,7 @@ describe('MarketplaceTrigger', () => { />, ) - expect(container.querySelector('.border-components-button-secondary-border')).toBeInTheDocument() + expect(container.querySelector('.border-components-button-secondary-border'))!.toBeInTheDocument() }) }) @@ -636,7 +675,7 @@ describe('MarketplaceTrigger', () => { , ) - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) }) }) @@ -661,13 +700,13 @@ describe('ToolSelectorTrigger', () => { it('should render without crashing', () => { const { container } = render() - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) it('should render price tag icon', () => { const { container } = render() - expect(container.querySelector('.size-4')).toBeInTheDocument() + expect(container.querySelector('.size-4'))!.toBeInTheDocument() }) }) @@ -681,7 +720,7 @@ describe('ToolSelectorTrigger', () => { />, ) - expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() }) it('should show multiple tag labels separated by comma', () => { @@ -693,7 +732,7 @@ describe('ToolSelectorTrigger', () => { />, ) - expect(screen.getByText('Agent,RAG')).toBeInTheDocument() + expect(screen.getByText('Agent,RAG'))!.toBeInTheDocument() }) it('should show +N indicator when more than 2 tags selected', () => { @@ -705,7 +744,7 @@ describe('ToolSelectorTrigger', () => { />, ) - expect(screen.getByText('+2')).toBeInTheDocument() + expect(screen.getByText('+2'))!.toBeInTheDocument() }) it('should not show tag labels when no tags selected', () => { @@ -725,7 +764,7 @@ describe('ToolSelectorTrigger', () => { />, ) - expect(container.querySelector('.text-text-quaternary')).toBeInTheDocument() + expect(container.querySelector('.text-text-quaternary'))!.toBeInTheDocument() }) it('should not show clear button when no tags selected', () => { @@ -785,7 +824,7 @@ describe('ToolSelectorTrigger', () => { , ) - expect(container.querySelector('.bg-state-base-hover')).toBeInTheDocument() + expect(container.querySelector('.bg-state-base-hover'))!.toBeInTheDocument() }) it('should apply border styling when tags are selected', () => { @@ -797,7 +836,7 @@ describe('ToolSelectorTrigger', () => { />, ) - expect(container.querySelector('.border-components-button-secondary-border')).toBeInTheDocument() + expect(container.querySelector('.border-components-button-secondary-border'))!.toBeInTheDocument() }) it('should not apply hover styling when open but has tags', () => { @@ -811,7 +850,8 @@ describe('ToolSelectorTrigger', () => { ) // Should have border styling, not hover - expect(container.querySelector('.border-components-button-secondary-border')).toBeInTheDocument() + // Should have border styling, not hover + expect(container.querySelector('.border-components-button-secondary-border'))!.toBeInTheDocument() }) }) @@ -826,7 +866,7 @@ describe('ToolSelectorTrigger', () => { />, ) - expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() }) }) }) @@ -854,7 +894,7 @@ describe('TagsFilter', () => { />, ) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-elem'))!.toBeInTheDocument() }) it('should pass usedInMarketplace prop to TagsFilter', () => { @@ -869,7 +909,8 @@ describe('TagsFilter', () => { ) // MarketplaceTrigger should show "All Tags" - expect(screen.getByText('All Tags')).toBeInTheDocument() + // MarketplaceTrigger should show "All Tags" + expect(screen.getByText('All Tags'))!.toBeInTheDocument() }) it('should show selected tags count in TagsFilter trigger', () => { @@ -883,7 +924,7 @@ describe('TagsFilter', () => { />, ) - expect(screen.getByText('+1')).toBeInTheDocument() + expect(screen.getByText('+1'))!.toBeInTheDocument() }) }) @@ -902,7 +943,7 @@ describe('TagsFilter', () => { fireEvent.click(trigger) await waitFor(() => { - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) }) @@ -921,7 +962,7 @@ describe('TagsFilter', () => { // Open fireEvent.click(trigger) await waitFor(() => { - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) // Close @@ -947,8 +988,8 @@ describe('TagsFilter', () => { fireEvent.click(trigger) await waitFor(() => { - expect(screen.getByText('Agent')).toBeInTheDocument() - expect(screen.getByText('RAG')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() + expect(screen.getByText('RAG'))!.toBeInTheDocument() }) }) @@ -967,7 +1008,7 @@ describe('TagsFilter', () => { fireEvent.click(trigger) await waitFor(() => { - expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() }) const agentOption = screen.getByText('Agent') @@ -1018,7 +1059,7 @@ describe('TagsFilter', () => { fireEvent.click(trigger) await waitFor(() => { - expect(screen.getByText('RAG')).toBeInTheDocument() + expect(screen.getByText('RAG'))!.toBeInTheDocument() }) const ragOption = screen.getByText('RAG') @@ -1061,7 +1102,7 @@ describe('TagsFilter', () => { fireEvent.click(trigger) await waitFor(() => { - expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() }) const inputs = screen.getAllByRole('textbox') @@ -1071,7 +1112,7 @@ describe('TagsFilter', () => { if (searchInput) { fireEvent.change(searchInput, { target: { value: 'agent' } }) - expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() } }) }) @@ -1097,7 +1138,8 @@ describe('TagsFilter', () => { }) // Verify dropdown content is rendered - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Verify dropdown content is rendered + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) it('should render tag options when dropdown is open', async () => { @@ -1114,13 +1156,14 @@ describe('TagsFilter', () => { fireEvent.click(trigger) await waitFor(() => { - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) // When no tags selected, these should appear once each in dropdown - expect(screen.getByText('Agent')).toBeInTheDocument() - expect(screen.getByText('RAG')).toBeInTheDocument() - expect(screen.getByText('Search')).toBeInTheDocument() + // When no tags selected, these should appear once each in dropdown + expect(screen.getByText('Agent'))!.toBeInTheDocument() + expect(screen.getByText('RAG'))!.toBeInTheDocument() + expect(screen.getByText('Search'))!.toBeInTheDocument() }) }) }) @@ -1146,8 +1189,8 @@ describe('Accessibility', () => { ) const input = screen.getByRole('textbox') - expect(input).toBeInTheDocument() - expect(input).toHaveAttribute('placeholder', 'Search plugins') + expect(input)!.toBeInTheDocument() + expect(input)!.toHaveAttribute('placeholder', 'Search plugins') }) it('should have clickable tag options in dropdown', async () => { @@ -1156,7 +1199,7 @@ describe('Accessibility', () => { fireEvent.click(screen.getByTestId('portal-trigger')) await waitFor(() => { - expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() }) }) }) @@ -1192,7 +1235,7 @@ describe('Combined Workflows', () => { fireEvent.click(trigger) await waitFor(() => { - expect(screen.getByText('Agent')).toBeInTheDocument() + expect(screen.getByText('Agent'))!.toBeInTheDocument() }) const agentOption = screen.getByText('Agent') @@ -1217,9 +1260,9 @@ describe('Combined Workflows', () => { />, ) - expect(screen.getByDisplayValue('test')).toBeInTheDocument() - expect(screen.getByText('Agent,RAG')).toBeInTheDocument() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByDisplayValue('test'))!.toBeInTheDocument() + expect(screen.getByText('Agent,RAG'))!.toBeInTheDocument() + expect(screen.getByTestId('portal-elem'))!.toBeInTheDocument() }) it('should handle prop changes correctly', () => { @@ -1234,7 +1277,7 @@ describe('Combined Workflows', () => { />, ) - expect(screen.getByDisplayValue('initial')).toBeInTheDocument() + expect(screen.getByDisplayValue('initial'))!.toBeInTheDocument() rerender( { />, ) - expect(screen.getByDisplayValue('updated')).toBeInTheDocument() + expect(screen.getByDisplayValue('updated'))!.toBeInTheDocument() }) }) diff --git a/web/app/components/plugins/marketplace/sort-dropdown/__tests__/index.spec.tsx b/web/app/components/plugins/marketplace/sort-dropdown/__tests__/index.spec.tsx index 4d93726c4c..990bb321de 100644 --- a/web/app/components/plugins/marketplace/sort-dropdown/__tests__/index.spec.tsx +++ b/web/app/components/plugins/marketplace/sort-dropdown/__tests__/index.spec.tsx @@ -1,15 +1,12 @@ -import { fireEvent, render, screen, within } from '@testing-library/react' +import type { + MouseEventHandler, + ReactNode, +} from 'react' +import { render, screen, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import { beforeEach, describe, expect, it, vi } from 'vitest' import SortDropdown from '../index' -// ================================ -// Mock external dependencies only -// ================================ - -// Mock i18n translation hook const mockTranslation = vi.fn((key: string, options?: { ns?: string }) => { - // Build full key with namespace prefix if provided const fullKey = options?.ns ? `${options.ns}.${key}` : key const translations: Record = { 'plugin.marketplace.sortBy': 'Sort by', @@ -27,7 +24,6 @@ vi.mock('#i18n', () => ({ }), })) -// Mock marketplace atoms with controllable values let mockSort: { sortBy: string, sortOrder: string } = { sortBy: 'install_count', sortOrder: 'DESC' } const mockHandleSortChange = vi.fn() @@ -35,664 +31,123 @@ vi.mock('../../atoms', () => ({ useMarketplaceSort: () => [mockSort, mockHandleSortChange], })) -// Mock portal component with controllable open state -let mockPortalOpenState = false +vi.mock('@/app/components/base/ui/dropdown-menu', async () => { + const React = await import('react') + const DropdownMenuContext = React.createContext<{ open: boolean, setOpen: (open: boolean) => void } | null>(null) -vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ - PortalToFollowElem: ({ children, open, onOpenChange: _onOpenChange }: { - children: React.ReactNode - open: boolean - onOpenChange: (open: boolean) => void - }) => { - mockPortalOpenState = open - return ( -
- {children} -
- ) - }, - PortalToFollowElemTrigger: ({ children, onClick }: { - children: React.ReactNode - onClick: () => void - }) => ( -
- {children} -
- ), - PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => { - // Match actual behavior: only render when portal is open - if (!mockPortalOpenState) - return null - return
{children}
- }, -})) + const useDropdownMenuContext = () => { + const context = React.use(DropdownMenuContext) + if (!context) + throw new Error('DropdownMenu components must be wrapped in DropdownMenu') + return context + } -// ================================ -// Test Factory Functions -// ================================ + return { + DropdownMenu: ({ children, open, onOpenChange }: { children: ReactNode, open: boolean, onOpenChange?: (open: boolean) => void }) => ( + +
+ {children} +
+
+ ), + DropdownMenuTrigger: ({ children, className }: { children: ReactNode, className?: string }) => { + const { open, setOpen } = useDropdownMenuContext() + return ( + + ) + }, + DropdownMenuContent: ({ children }: { children: ReactNode }) => { + const { open } = useDropdownMenuContext() + return open ?
{children}
: null + }, + DropdownMenuItem: ({ + children, + onClick, + className, + }: { + children: ReactNode + onClick?: MouseEventHandler + className?: string + }) => { + const { setOpen } = useDropdownMenuContext() + return ( + + ) + }, + } +}) -type SortOption = { - value: string - order: string - text: string -} - -const createSortOptions = (): SortOption[] => [ - { value: 'install_count', order: 'DESC', text: 'Most Popular' }, - { value: 'version_updated_at', order: 'DESC', text: 'Recently Updated' }, - { value: 'created_at', order: 'DESC', text: 'Newly Released' }, - { value: 'created_at', order: 'ASC', text: 'First Released' }, -] - -// ================================ -// SortDropdown Component Tests -// ================================ describe('SortDropdown', () => { beforeEach(() => { vi.clearAllMocks() mockSort = { sortBy: 'install_count', sortOrder: 'DESC' } - mockPortalOpenState = false }) - // ================================ - // Rendering Tests - // ================================ - describe('Rendering', () => { - it('should render without crashing', () => { - render() + it('renders the selected sort option in the trigger', () => { + render() - expect(screen.getByTestId('portal-wrapper')).toBeInTheDocument() - }) - - it('should render sort by label', () => { - render() - - expect(screen.getByText('Sort by')).toBeInTheDocument() - }) - - it('should render selected option text', () => { - render() - - expect(screen.getByText('Most Popular')).toBeInTheDocument() - }) - - it('should render arrow down icon', () => { - const { container } = render() - - const arrowIcon = container.querySelector('.h-4.w-4.text-text-tertiary') - expect(arrowIcon).toBeInTheDocument() - }) - - it('should render trigger element with correct styles', () => { - const { container } = render() - - const trigger = container.querySelector('.cursor-pointer') - expect(trigger).toBeInTheDocument() - expect(trigger).toHaveClass('h-8', 'rounded-lg', 'bg-state-base-hover-alt') - }) - - it('should not render dropdown content when closed', () => { - render() - - expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument() - }) + const trigger = screen.getByTestId('dropdown-trigger') + expect(within(trigger).getByText('Sort by')).toBeInTheDocument() + expect(within(trigger).getByText('Most Popular')).toBeInTheDocument() }) - // ================================ - // State Management Tests - // ================================ - describe('State Management', () => { - it('should initialize with closed state', () => { - render() + it('falls back to the default option when the current sort is invalid', () => { + mockSort = { sortBy: 'unknown', sortOrder: 'ASC' } - const wrapper = screen.getByTestId('portal-wrapper') - expect(wrapper).toHaveAttribute('data-open', 'false') - }) + render() - it('should display correct selected option for install_count DESC', () => { - mockSort = { sortBy: 'install_count', sortOrder: 'DESC' } - render() - - expect(screen.getByText('Most Popular')).toBeInTheDocument() - }) - - it('should display correct selected option for version_updated_at DESC', () => { - mockSort = { sortBy: 'version_updated_at', sortOrder: 'DESC' } - render() - - expect(screen.getByText('Recently Updated')).toBeInTheDocument() - }) - - it('should display correct selected option for created_at DESC', () => { - mockSort = { sortBy: 'created_at', sortOrder: 'DESC' } - render() - - expect(screen.getByText('Newly Released')).toBeInTheDocument() - }) - - it('should display correct selected option for created_at ASC', () => { - mockSort = { sortBy: 'created_at', sortOrder: 'ASC' } - render() - - expect(screen.getByText('First Released')).toBeInTheDocument() - }) - - it('should toggle open state when trigger clicked', () => { - render() - - const trigger = screen.getByTestId('portal-trigger') - fireEvent.click(trigger) - - // After click, portal content should be visible - expect(screen.getByTestId('portal-content')).toBeInTheDocument() - }) - - it('should close dropdown when trigger clicked again', () => { - render() - - const trigger = screen.getByTestId('portal-trigger') - - // Open - fireEvent.click(trigger) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() - - // Close - fireEvent.click(trigger) - expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument() - }) + expect(screen.getByText('Most Popular')).toBeInTheDocument() }) - // ================================ - // User Interactions Tests - // ================================ - describe('User Interactions', () => { - it('should open dropdown on trigger click', () => { - render() + it('opens the menu and renders all sort options', async () => { + const user = userEvent.setup() + render() - const trigger = screen.getByTestId('portal-trigger') - fireEvent.click(trigger) + await user.click(screen.getByTestId('dropdown-trigger')) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() - }) - - it('should render all sort options when open', () => { - render() - - // Open dropdown - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - expect(within(content).getByText('Most Popular')).toBeInTheDocument() - expect(within(content).getByText('Recently Updated')).toBeInTheDocument() - expect(within(content).getByText('Newly Released')).toBeInTheDocument() - expect(within(content).getByText('First Released')).toBeInTheDocument() - }) - - it('should call handleSortChange when option clicked', () => { - render() - - // Open dropdown - fireEvent.click(screen.getByTestId('portal-trigger')) - - // Click on "Recently Updated" - const content = screen.getByTestId('portal-content') - fireEvent.click(within(content).getByText('Recently Updated')) - - expect(mockHandleSortChange).toHaveBeenCalledWith({ - sortBy: 'version_updated_at', - sortOrder: 'DESC', - }) - }) - - it('should call handleSortChange with correct params for Most Popular', () => { - mockSort = { sortBy: 'created_at', sortOrder: 'DESC' } - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - fireEvent.click(within(content).getByText('Most Popular')) - - expect(mockHandleSortChange).toHaveBeenCalledWith({ - sortBy: 'install_count', - sortOrder: 'DESC', - }) - }) - - it('should call handleSortChange with correct params for Newly Released', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - fireEvent.click(within(content).getByText('Newly Released')) - - expect(mockHandleSortChange).toHaveBeenCalledWith({ - sortBy: 'created_at', - sortOrder: 'DESC', - }) - }) - - it('should call handleSortChange with correct params for First Released', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - fireEvent.click(within(content).getByText('First Released')) - - expect(mockHandleSortChange).toHaveBeenCalledWith({ - sortBy: 'created_at', - sortOrder: 'ASC', - }) - }) - - it('should allow selecting currently selected option', () => { - mockSort = { sortBy: 'install_count', sortOrder: 'DESC' } - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - fireEvent.click(within(content).getByText('Most Popular')) - - expect(mockHandleSortChange).toHaveBeenCalledWith({ - sortBy: 'install_count', - sortOrder: 'DESC', - }) - }) - - it('should support userEvent for trigger click', async () => { - const user = userEvent.setup() - render() - - const trigger = screen.getByTestId('portal-trigger') - await user.click(trigger) - - expect(screen.getByTestId('portal-content')).toBeInTheDocument() - }) + const content = screen.getByTestId('dropdown-content') + expect(within(content).getByText('Most Popular')).toBeInTheDocument() + expect(within(content).getByText('Recently Updated')).toBeInTheDocument() + expect(within(content).getByText('Newly Released')).toBeInTheDocument() + expect(within(content).getByText('First Released')).toBeInTheDocument() }) - // ================================ - // Check Icon Tests - // ================================ - describe('Check Icon', () => { - it('should show check icon for selected option', () => { - mockSort = { sortBy: 'install_count', sortOrder: 'DESC' } - const { container } = render() + it('shows a check icon for the currently selected option', async () => { + const user = userEvent.setup() + const { container } = render() - // Open dropdown - fireEvent.click(screen.getByTestId('portal-trigger')) + await user.click(screen.getByTestId('dropdown-trigger')) - // Check icon should be present in the dropdown - const checkIcon = container.querySelector('.text-text-accent') - expect(checkIcon).toBeInTheDocument() - }) - - it('should show check icon only for matching sortBy AND sortOrder', () => { - mockSort = { sortBy: 'created_at', sortOrder: 'DESC' } - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - const options = content.querySelectorAll('.cursor-pointer') - - // "Newly Released" (created_at DESC) should have check icon - // "First Released" (created_at ASC) should NOT have check icon - expect(options.length).toBe(4) - }) - - it('should not show check icon for different sortOrder with same sortBy', () => { - mockSort = { sortBy: 'created_at', sortOrder: 'DESC' } - const { container } = render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - // Only one check icon should be visible (for Newly Released, not First Released) - const checkIcons = container.querySelectorAll('.text-text-accent') - expect(checkIcons.length).toBe(1) - }) + expect(container.querySelector('.i-ri-check-line')).toBeInTheDocument() }) - // ================================ - // Dropdown Options Structure Tests - // ================================ - describe('Dropdown Options Structure', () => { - const sortOptions = createSortOptions() + it('updates the sort and closes the menu when an option is selected', async () => { + const user = userEvent.setup() + render() - it('should render 4 sort options', () => { - render() + await user.click(screen.getByTestId('dropdown-trigger')) + await user.click(screen.getByText('Recently Updated')) - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - const options = content.querySelectorAll('.cursor-pointer') - expect(options.length).toBe(4) + expect(mockHandleSortChange).toHaveBeenCalledWith({ + sortBy: 'version_updated_at', + sortOrder: 'DESC', }) - - it.each(sortOptions)('should render option: $text', ({ text }) => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - expect(within(content).getByText(text)).toBeInTheDocument() - }) - - it('should render options with unique keys', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - const options = content.querySelectorAll('.cursor-pointer') - - // All options should be rendered (no key conflicts) - expect(options.length).toBe(4) - }) - - it('should render dropdown container with correct styles', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - const container = content.firstChild as HTMLElement - expect(container).toHaveClass('rounded-xl', 'shadow-lg') - }) - - it('should render option items with hover styles', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - const option = content.querySelector('.cursor-pointer') - expect(option).toHaveClass('hover:bg-components-panel-on-panel-item-bg-hover') - }) - }) - - // ================================ - // Edge Cases Tests - // ================================ - describe('Edge Cases', () => { - // The component falls back to the first option (Most Popular) when sort values are invalid - - it('should fallback to default option when sortBy is unknown', () => { - mockSort = { sortBy: 'unknown_field', sortOrder: 'DESC' } - - render() - - // Should fallback to first option "Most Popular" - expect(screen.getByText('Most Popular')).toBeInTheDocument() - }) - - it('should fallback to default option when sortBy is empty', () => { - mockSort = { sortBy: '', sortOrder: 'DESC' } - - render() - - expect(screen.getByText('Most Popular')).toBeInTheDocument() - }) - - it('should fallback to default option when sortOrder is unknown', () => { - mockSort = { sortBy: 'install_count', sortOrder: 'UNKNOWN' } - - render() - - expect(screen.getByText('Most Popular')).toBeInTheDocument() - }) - - it('should render correctly when handleSortChange is a no-op', () => { - mockHandleSortChange.mockImplementation(() => {}) - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - fireEvent.click(within(content).getByText('Recently Updated')) - - expect(mockHandleSortChange).toHaveBeenCalled() - }) - - it('should handle rapid toggle clicks', () => { - render() - - const trigger = screen.getByTestId('portal-trigger') - - // Rapid clicks - fireEvent.click(trigger) - fireEvent.click(trigger) - fireEvent.click(trigger) - - // Final state should be open (odd number of clicks) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() - }) - - it('should handle multiple option selections', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - - // Click multiple options - fireEvent.click(within(content).getByText('Recently Updated')) - fireEvent.click(within(content).getByText('Newly Released')) - fireEvent.click(within(content).getByText('First Released')) - - expect(mockHandleSortChange).toHaveBeenCalledTimes(3) - }) - }) - - // ================================ - // Context Integration Tests - // ================================ - describe('Context Integration', () => { - it('should read sort value from context', () => { - mockSort = { sortBy: 'version_updated_at', sortOrder: 'DESC' } - render() - - expect(screen.getByText('Recently Updated')).toBeInTheDocument() - }) - - it('should call context handleSortChange on selection', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - fireEvent.click(within(content).getByText('First Released')) - - expect(mockHandleSortChange).toHaveBeenCalledWith({ - sortBy: 'created_at', - sortOrder: 'ASC', - }) - }) - - it('should update display when context sort changes', () => { - const { rerender } = render() - - expect(screen.getByText('Most Popular')).toBeInTheDocument() - - // Simulate context change - mockSort = { sortBy: 'created_at', sortOrder: 'ASC' } - rerender() - - expect(screen.getByText('First Released')).toBeInTheDocument() - }) - - it('should use selector pattern correctly', () => { - render() - - // Component should have called useMarketplaceContext with selector functions - expect(screen.getByTestId('portal-wrapper')).toBeInTheDocument() - }) - }) - - // ================================ - // Accessibility Tests - // ================================ - describe('Accessibility', () => { - it('should have cursor pointer on trigger', () => { - const { container } = render() - - const trigger = container.querySelector('.cursor-pointer') - expect(trigger).toBeInTheDocument() - }) - - it('should have cursor pointer on options', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - const options = content.querySelectorAll('.cursor-pointer') - expect(options.length).toBeGreaterThan(0) - }) - - it('should have visible focus indicators via hover styles', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - const option = content.querySelector('.hover\\:bg-components-panel-on-panel-item-bg-hover') - expect(option).toBeInTheDocument() - }) - }) - - // ================================ - // Translation Tests - // ================================ - describe('Translations', () => { - it('should call translation for sortBy label', () => { - render() - - expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortBy', { ns: 'plugin' }) - }) - - it('should call translation for all sort options', () => { - render() - - expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortOption.mostPopular', { ns: 'plugin' }) - expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortOption.recentlyUpdated', { ns: 'plugin' }) - expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortOption.newlyReleased', { ns: 'plugin' }) - expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortOption.firstReleased', { ns: 'plugin' }) - }) - }) - - // ================================ - // Portal Component Integration Tests - // ================================ - describe('Portal Component Integration', () => { - it('should pass open state to PortalToFollowElem', () => { - render() - - const wrapper = screen.getByTestId('portal-wrapper') - expect(wrapper).toHaveAttribute('data-open', 'false') - - fireEvent.click(screen.getByTestId('portal-trigger')) - - expect(wrapper).toHaveAttribute('data-open', 'true') - }) - - it('should render trigger content inside PortalToFollowElemTrigger', () => { - render() - - const trigger = screen.getByTestId('portal-trigger') - expect(within(trigger).getByText('Sort by')).toBeInTheDocument() - expect(within(trigger).getByText('Most Popular')).toBeInTheDocument() - }) - - it('should render options inside PortalToFollowElemContent', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - expect(within(content).getByText('Most Popular')).toBeInTheDocument() - }) - }) - - // ================================ - // Visual Style Tests - // ================================ - describe('Visual Styles', () => { - it('should apply correct trigger container styles', () => { - const { container } = render() - - const triggerDiv = container.querySelector('.flex.h-8.cursor-pointer.items-center.rounded-lg') - expect(triggerDiv).toBeInTheDocument() - }) - - it('should apply secondary text color to sort by label', () => { - const { container } = render() - - const label = container.querySelector('.text-text-secondary') - expect(label).toBeInTheDocument() - expect(label?.textContent).toBe('Sort by') - }) - - it('should apply primary text color to selected option', () => { - const { container } = render() - - const selected = container.querySelector('.text-text-primary.system-sm-medium') - expect(selected).toBeInTheDocument() - }) - - it('should apply tertiary text color to arrow icon', () => { - const { container } = render() - - const arrow = container.querySelector('.text-text-tertiary') - expect(arrow).toBeInTheDocument() - }) - - it('should apply accent text color to check icon when option selected', () => { - mockSort = { sortBy: 'install_count', sortOrder: 'DESC' } - const { container } = render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const checkIcon = container.querySelector('.text-text-accent') - expect(checkIcon).toBeInTheDocument() - }) - - it('should apply blur-sm backdrop to dropdown container', () => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - const container = content.querySelector('.backdrop-blur-xs') - expect(container).toBeInTheDocument() - }) - }) - - // ================================ - // All Sort Options Click Tests - // ================================ - describe('All Sort Options Click Handlers', () => { - const testCases = [ - { text: 'Most Popular', sortBy: 'install_count', sortOrder: 'DESC' }, - { text: 'Recently Updated', sortBy: 'version_updated_at', sortOrder: 'DESC' }, - { text: 'Newly Released', sortBy: 'created_at', sortOrder: 'DESC' }, - { text: 'First Released', sortBy: 'created_at', sortOrder: 'ASC' }, - ] - - it.each(testCases)( - 'should call handleSortChange with { sortBy: "$sortBy", sortOrder: "$sortOrder" } when clicking "$text"', - ({ text, sortBy, sortOrder }) => { - render() - - fireEvent.click(screen.getByTestId('portal-trigger')) - - const content = screen.getByTestId('portal-content') - fireEvent.click(within(content).getByText(text)) - - expect(mockHandleSortChange).toHaveBeenCalledWith({ sortBy, sortOrder }) - }, - ) + expect(screen.queryByTestId('dropdown-content')).not.toBeInTheDocument() }) }) diff --git a/web/app/components/plugins/marketplace/sort-dropdown/index.tsx b/web/app/components/plugins/marketplace/sort-dropdown/index.tsx index dddfab5402..a47143de02 100644 --- a/web/app/components/plugins/marketplace/sort-dropdown/index.tsx +++ b/web/app/components/plugins/marketplace/sort-dropdown/index.tsx @@ -1,15 +1,12 @@ 'use client' import { useTranslation } from '#i18n' -import { - RiArrowDownSLine, - RiCheckLine, -} from '@remixicon/react' import { useState } from 'react' import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' import { useMarketplaceSort } from '../atoms' const SortDropdown = () => { @@ -38,50 +35,44 @@ const SortDropdown = () => { ] const [sort, handleSortChange] = useMarketplaceSort() const [open, setOpen] = useState(false) - const selectedOption = options.find(option => option.value === sort.sortBy && option.order === sort.sortOrder) ?? options[0] + const selectedOption = options.find(option => option.value === sort.sortBy && option.order === sort.sortOrder) ?? options[0]! return ( - - setOpen(v => !v)}> -
- - {t('marketplace.sortBy', { ns: 'plugin' })} - - - {selectedOption.text} - - -
-
- -
- { - options.map(option => ( -
handleSortChange({ sortBy: option.value, sortOrder: option.order })} - > - {option.text} - { - sort.sortBy === option.value && sort.sortOrder === option.order && ( - - ) - } -
- )) - } -
-
-
+ + + {t('marketplace.sortBy', { ns: 'plugin' })} + + + {selectedOption.text} + + + + + {options.map(option => ( + { + handleSortChange({ sortBy: option.value, sortOrder: option.order }) + setOpen(false) + }} + > + {option.text} + {sort.sortBy === option.value && sort.sortOrder === option.order && ( + + )} + + ))} + + ) } diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx index 74ce8525a9..0eacbf3bd3 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx @@ -297,13 +297,13 @@ describe('DetailHeader', () => { it('should render plugin title', () => { render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should render plugin icon with correct src', () => { render() - expect(screen.getByTestId('card-icon')).toBeInTheDocument() + expect(screen.getByTestId('card-icon'))!.toBeInTheDocument() }) it('should render icon with http url directly', () => { @@ -315,13 +315,13 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByTestId('card-icon')).toHaveAttribute('data-src', 'https://example.com/icon.png') + expect(screen.getByTestId('card-icon'))!.toHaveAttribute('data-src', 'https://example.com/icon.png') }) it('should render description when not in readme view', () => { render() - expect(screen.getByTestId('description')).toBeInTheDocument() + expect(screen.getByTestId('description'))!.toBeInTheDocument() }) it('should not render description in readme view', () => { @@ -333,7 +333,7 @@ describe('DetailHeader', () => { it('should render verified badge when verified', () => { render() - expect(screen.getByTestId('verified-badge')).toBeInTheDocument() + expect(screen.getByTestId('verified-badge'))!.toBeInTheDocument() }) }) @@ -346,7 +346,8 @@ describe('DetailHeader', () => { render() // Badge component should render with the version - expect(screen.getByText('1.0.0')).toBeInTheDocument() + // Badge component should render with the version + expect(screen.getByText('1.0.0'))!.toBeInTheDocument() }) it('should not show new version indicator when versions match', () => { @@ -357,7 +358,8 @@ describe('DetailHeader', () => { render() // Badge component should render with the version - expect(screen.getByText('1.0.0')).toBeInTheDocument() + // Badge component should render with the version + expect(screen.getByText('1.0.0'))!.toBeInTheDocument() }) it('should show update button when new version is available', () => { @@ -367,7 +369,7 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByText('plugin.detailPanel.operation.update')).toBeInTheDocument() + expect(screen.getByText('plugin.detailPanel.operation.update'))!.toBeInTheDocument() }) it('should show update button for GitHub source', () => { @@ -377,7 +379,7 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByText('plugin.detailPanel.operation.update')).toBeInTheDocument() + expect(screen.getByText('plugin.detailPanel.operation.update'))!.toBeInTheDocument() }) }) @@ -393,7 +395,7 @@ describe('DetailHeader', () => { render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should render component when strategy is disabled', () => { @@ -407,7 +409,7 @@ describe('DetailHeader', () => { render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should enable auto upgrade for update_all mode', () => { @@ -422,7 +424,8 @@ describe('DetailHeader', () => { render() // Auto upgrade badge should be rendered - expect(screen.getByTestId('title')).toBeInTheDocument() + // Auto upgrade badge should be rendered + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should enable auto upgrade for partial mode when plugin is included', () => { @@ -436,7 +439,7 @@ describe('DetailHeader', () => { render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should not enable auto upgrade for partial mode when plugin is not included', () => { @@ -450,7 +453,7 @@ describe('DetailHeader', () => { render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should enable auto upgrade for exclude mode when plugin is not excluded', () => { @@ -464,7 +467,7 @@ describe('DetailHeader', () => { render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should not enable auto upgrade for exclude mode when plugin is excluded', () => { @@ -478,7 +481,7 @@ describe('DetailHeader', () => { render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should not enable auto upgrade for non-marketplace plugins', () => { @@ -496,7 +499,7 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should not enable auto upgrade when marketplace feature is disabled', () => { @@ -512,7 +515,8 @@ describe('DetailHeader', () => { render() // Component should still render but auto upgrade should be disabled - expect(screen.getByTestId('title')).toBeInTheDocument() + // Component should still render but auto upgrade should be disabled + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) }) @@ -522,7 +526,7 @@ describe('DetailHeader', () => { // Find the close button (ActionButton with action-btn class) const actionButtons = screen.getAllByRole('button').filter(btn => btn.classList.contains('action-btn')) - fireEvent.click(actionButtons[actionButtons.length - 1]) + fireEvent.click(actionButtons[actionButtons.length - 1]!) expect(mockOnHide).toHaveBeenCalled() }) @@ -533,7 +537,7 @@ describe('DetailHeader', () => { const infoBtn = screen.getByTestId('info-btn') fireEvent.click(infoBtn) - expect(infoBtn).toBeInTheDocument() + expect(infoBtn)!.toBeInTheDocument() }) it('should have check version button available', () => { @@ -542,7 +546,7 @@ describe('DetailHeader', () => { const checkBtn = screen.getByTestId('check-version-btn') fireEvent.click(checkBtn) - expect(checkBtn).toBeInTheDocument() + expect(checkBtn)!.toBeInTheDocument() }) }) @@ -557,7 +561,7 @@ describe('DetailHeader', () => { const updateBtn = screen.getByText('plugin.detailPanel.operation.update') fireEvent.click(updateBtn) - expect(updateBtn).toBeInTheDocument() + expect(updateBtn)!.toBeInTheDocument() }) it('should have version picker select button', () => { @@ -566,7 +570,7 @@ describe('DetailHeader', () => { const selectBtn = screen.getByTestId('select-version-btn') fireEvent.click(selectBtn) - expect(selectBtn).toBeInTheDocument() + expect(selectBtn)!.toBeInTheDocument() }) it('should have downgrade button', () => { @@ -575,7 +579,7 @@ describe('DetailHeader', () => { const downgradeBtn = screen.getByTestId('select-downgrade-btn') fireEvent.click(downgradeBtn) - expect(downgradeBtn).toBeInTheDocument() + expect(downgradeBtn)!.toBeInTheDocument() }) }) @@ -651,7 +655,7 @@ describe('DetailHeader', () => { const removeBtn = screen.getByTestId('remove-btn') fireEvent.click(removeBtn) - expect(removeBtn).toBeInTheDocument() + expect(removeBtn)!.toBeInTheDocument() }) it('should have uninstallPlugin mock defined', () => { @@ -671,13 +675,13 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByTestId('remove-btn')).toBeInTheDocument() + expect(screen.getByTestId('remove-btn'))!.toBeInTheDocument() }) it('should render correctly for tool plugin delete', () => { render() - expect(screen.getByTestId('remove-btn')).toBeInTheDocument() + expect(screen.getByTestId('remove-btn'))!.toBeInTheDocument() }) }) @@ -689,21 +693,21 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should render local source icon', () => { const detail = createPluginDetail({ source: PluginSource.local }) render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should render debugging source icon', () => { const detail = createPluginDetail({ source: PluginSource.debugging }) render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should not render deprecation notice for non-marketplace source', () => { @@ -722,20 +726,20 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByTestId('operation-dropdown')).toBeInTheDocument() + expect(screen.getByTestId('operation-dropdown'))!.toBeInTheDocument() }) it('should render marketplace source correctly', () => { render() - expect(screen.getByTestId('operation-dropdown')).toBeInTheDocument() + expect(screen.getByTestId('operation-dropdown'))!.toBeInTheDocument() }) it('should render local source correctly', () => { const detail = createPluginDetail({ source: PluginSource.local }) render() - expect(screen.getByTestId('operation-dropdown')).toBeInTheDocument() + expect(screen.getByTestId('operation-dropdown'))!.toBeInTheDocument() }) }) @@ -743,7 +747,7 @@ describe('DetailHeader', () => { it('should render plugin auth for tool category', () => { render() - expect(screen.getByTestId('plugin-auth')).toBeInTheDocument() + expect(screen.getByTestId('plugin-auth'))!.toBeInTheDocument() }) it('should not render plugin auth for non-tool category', () => { @@ -770,7 +774,7 @@ describe('DetailHeader', () => { const detail = createPluginDetail({ version: '' }) render() - expect(screen.getByTestId('title')).toBeInTheDocument() + expect(screen.getByTestId('title'))!.toBeInTheDocument() }) it('should handle plugin with name containing slash', () => { @@ -782,7 +786,7 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByTestId('org-info')).toBeInTheDocument() + expect(screen.getByTestId('org-info'))!.toBeInTheDocument() }) it('should handle empty icon', () => { @@ -794,7 +798,7 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByTestId('card-icon')).toHaveAttribute('data-src', '') + expect(screen.getByTestId('card-icon'))!.toHaveAttribute('data-src', '') }) }) @@ -805,7 +809,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('remove-btn')) await waitFor(() => { - expect(screen.getByRole('alertdialog')).toBeInTheDocument() + expect(screen.getByRole('alertdialog'))!.toBeInTheDocument() }) }) @@ -814,7 +818,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('remove-btn')) await waitFor(() => { - expect(screen.getByRole('alertdialog')).toBeInTheDocument() + expect(screen.getByRole('alertdialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) @@ -829,7 +833,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('remove-btn')) await waitFor(() => { - expect(screen.getByRole('alertdialog')).toBeInTheDocument() + expect(screen.getByRole('alertdialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) @@ -844,7 +848,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('remove-btn')) await waitFor(() => { - expect(screen.getByRole('alertdialog')).toBeInTheDocument() + expect(screen.getByRole('alertdialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) @@ -865,7 +869,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('remove-btn')) await waitFor(() => { - expect(screen.getByRole('alertdialog')).toBeInTheDocument() + expect(screen.getByRole('alertdialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) @@ -880,7 +884,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('remove-btn')) await waitFor(() => { - expect(screen.getByRole('alertdialog')).toBeInTheDocument() + expect(screen.getByRole('alertdialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) @@ -895,7 +899,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('remove-btn')) await waitFor(() => { - expect(screen.getByRole('alertdialog')).toBeInTheDocument() + expect(screen.getByRole('alertdialog'))!.toBeInTheDocument() }) fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) @@ -917,7 +921,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByText('plugin.detailPanel.operation.update')) await waitFor(() => { - expect(screen.getByTestId('update-modal')).toBeInTheDocument() + expect(screen.getByTestId('update-modal'))!.toBeInTheDocument() }) }) @@ -930,7 +934,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByText('plugin.detailPanel.operation.update')) await waitFor(() => { - expect(screen.getByTestId('update-modal')).toBeInTheDocument() + expect(screen.getByTestId('update-modal'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('update-modal-save')) @@ -949,7 +953,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByText('plugin.detailPanel.operation.update')) await waitFor(() => { - expect(screen.getByTestId('update-modal')).toBeInTheDocument() + expect(screen.getByTestId('update-modal'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('update-modal-cancel')) @@ -967,7 +971,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('info-btn')) await waitFor(() => { - expect(screen.getByTestId('plugin-info')).toBeInTheDocument() + expect(screen.getByTestId('plugin-info'))!.toBeInTheDocument() }) }) @@ -976,7 +980,7 @@ describe('DetailHeader', () => { fireEvent.click(screen.getByTestId('info-btn')) await waitFor(() => { - expect(screen.getByTestId('plugin-info')).toBeInTheDocument() + expect(screen.getByTestId('plugin-info'))!.toBeInTheDocument() }) fireEvent.click(screen.getByTestId('plugin-info-close')) @@ -993,7 +997,7 @@ describe('DetailHeader', () => { }) render() - expect(screen.getByTestId('info-btn')).toBeInTheDocument() + expect(screen.getByTestId('info-btn'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/index.spec.tsx index 4dd604a03e..f7dd1921e4 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/__tests__/index.spec.tsx @@ -393,19 +393,20 @@ describe('AppTrigger', () => { it('should render placeholder when no app is selected', () => { render() // i18n mock returns key with namespace in dot format - expect(screen.getByText('app.appSelector.placeholder')).toBeInTheDocument() + // i18n mock returns key with namespace in dot format + expect(screen.getByText('app.appSelector.placeholder'))!.toBeInTheDocument() }) it('should render app details when app is selected', () => { const app = createMockApp({ name: 'My Test App' }) render() - expect(screen.getByText('My Test App')).toBeInTheDocument() + expect(screen.getByText('My Test App'))!.toBeInTheDocument() }) it('should apply open state styling', () => { const { container } = render() const trigger = container.querySelector('.bg-state-base-hover-alt') - expect(trigger).toBeInTheDocument() + expect(trigger)!.toBeInTheDocument() }) it('should render AppIcon when app is provided', () => { @@ -413,21 +414,21 @@ describe('AppTrigger', () => { const { container } = render() // AppIcon renders with a specific class when app is provided const iconContainer = container.querySelector('.mr-2') - expect(iconContainer).toBeInTheDocument() + expect(iconContainer)!.toBeInTheDocument() }) }) describe('Props', () => { it('should handle undefined appDetail gracefully', () => { render() - expect(screen.getByText('app.appSelector.placeholder')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.placeholder'))!.toBeInTheDocument() }) it('should display app name with title attribute', () => { const app = createMockApp({ name: 'Long App Name For Testing' }) render() const nameElement = screen.getByTitle('Long App Name For Testing') - expect(nameElement).toBeInTheDocument() + expect(nameElement)!.toBeInTheDocument() }) }) @@ -435,14 +436,14 @@ describe('AppTrigger', () => { it('should have correct base classes', () => { const { container } = render() const trigger = container.firstChild as HTMLElement - expect(trigger).toHaveClass('group', 'flex', 'cursor-pointer') + expect(trigger)!.toHaveClass('group', 'flex', 'cursor-pointer') }) it('should apply different padding when app is provided', () => { const app = createMockApp() const { container } = render() const trigger = container.firstChild as HTMLElement - expect(trigger).toHaveClass('py-1.5', 'pl-1.5') + expect(trigger)!.toHaveClass('py-1.5', 'pl-1.5') }) }) }) @@ -479,18 +480,18 @@ describe('AppPicker', () => { describe('Rendering', () => { it('should render trigger element', () => { render() - expect(screen.getByText('Select App')).toBeInTheDocument() + expect(screen.getByText('Select App'))!.toBeInTheDocument() }) it('should render app list when open', () => { render() - expect(screen.getByText('App 1')).toBeInTheDocument() - expect(screen.getByText('App 2')).toBeInTheDocument() + expect(screen.getByText('App 1'))!.toBeInTheDocument() + expect(screen.getByText('App 2'))!.toBeInTheDocument() }) it('should show loading indicator when isLoading is true', () => { render() - expect(screen.getByText('common.loading')).toBeInTheDocument() + expect(screen.getByText('common.loading'))!.toBeInTheDocument() }) it('should not render content when isShow is false', () => { @@ -538,31 +539,31 @@ describe('AppPicker', () => { it('should display correct app type for CHAT', () => { const apps = [createMockApp({ id: 'chat-app', name: 'Chat App', mode: AppModeEnum.CHAT })] render() - expect(screen.getByText('chat')).toBeInTheDocument() + expect(screen.getByText('chat'))!.toBeInTheDocument() }) it('should display correct app type for WORKFLOW', () => { const apps = [createMockApp({ id: 'workflow-app', name: 'Workflow App', mode: AppModeEnum.WORKFLOW })] render() - expect(screen.getByText('workflow')).toBeInTheDocument() + expect(screen.getByText('workflow'))!.toBeInTheDocument() }) it('should display correct app type for ADVANCED_CHAT', () => { const apps = [createMockApp({ id: 'chatflow-app', name: 'Chatflow App', mode: AppModeEnum.ADVANCED_CHAT })] render() - expect(screen.getByText('chatflow')).toBeInTheDocument() + expect(screen.getByText('chatflow'))!.toBeInTheDocument() }) it('should display correct app type for AGENT_CHAT', () => { const apps = [createMockApp({ id: 'agent-app', name: 'Agent App', mode: AppModeEnum.AGENT_CHAT })] render() - expect(screen.getByText('agent')).toBeInTheDocument() + expect(screen.getByText('agent'))!.toBeInTheDocument() }) it('should display correct app type for COMPLETION', () => { const apps = [createMockApp({ id: 'completion-app', name: 'Completion App', mode: AppModeEnum.COMPLETION })] render() - expect(screen.getByText('completion')).toBeInTheDocument() + expect(screen.getByText('completion'))!.toBeInTheDocument() }) }) @@ -575,7 +576,7 @@ describe('AppPicker', () => { it('should handle search text with value', () => { render() const input = screen.getByTestId('input') - expect(input).toHaveValue('test search') + expect(input)!.toHaveValue('test search') }) }) @@ -641,7 +642,8 @@ describe('AppPicker', () => { render() // The component should render without errors - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // The component should render without errors + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle isShow toggle correctly', () => { @@ -654,7 +656,8 @@ describe('AppPicker', () => { rerender() // Should not crash - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // Should not crash + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should setup intersection observer when isShow is true', () => { @@ -674,7 +677,8 @@ describe('AppPicker', () => { rerender() // Component should render without errors - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // Component should render without errors + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should cleanup observer on component unmount', () => { @@ -691,7 +695,8 @@ describe('AppPicker', () => { triggerMutationObserver() // Component should still work correctly - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // Component should still work correctly + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should not setup IntersectionObserver when observerTarget is null', () => { @@ -699,7 +704,8 @@ describe('AppPicker', () => { render() // The guard at line 84 should prevent setup - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // The guard at line 84 should prevent setup + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should debounce onLoadMore calls using loadingRef', () => { @@ -798,8 +804,8 @@ describe('AppInputsForm', () => { { type: InputVarType.textInput, label: 'Name', variable: 'name', required: false }, ] render() - expect(screen.getByText('Name')).toBeInTheDocument() - expect(screen.getByPlaceholderText('Name')).toBeInTheDocument() + expect(screen.getByText('Name'))!.toBeInTheDocument() + expect(screen.getByPlaceholderText('Name'))!.toBeInTheDocument() }) it('should render number input field', () => { @@ -807,7 +813,7 @@ describe('AppInputsForm', () => { { type: InputVarType.number, label: 'Count', variable: 'count', required: false }, ] render() - expect(screen.getByText('Count')).toBeInTheDocument() + expect(screen.getByText('Count'))!.toBeInTheDocument() }) it('should render paragraph (textarea) field', () => { @@ -815,7 +821,7 @@ describe('AppInputsForm', () => { { type: InputVarType.paragraph, label: 'Description', variable: 'desc', required: false }, ] render() - expect(screen.getByText('Description')).toBeInTheDocument() + expect(screen.getByText('Description'))!.toBeInTheDocument() }) it('should render select field', () => { @@ -840,8 +846,8 @@ describe('AppInputsForm', () => { }, ] render() - expect(screen.getByText('Single File Upload')).toBeInTheDocument() - expect(screen.getByTestId('file-uploader')).toBeInTheDocument() + expect(screen.getByText('Single File Upload'))!.toBeInTheDocument() + expect(screen.getByTestId('file-uploader'))!.toBeInTheDocument() }) it('should render file uploader for single file with existing value', () => { @@ -859,7 +865,8 @@ describe('AppInputsForm', () => { ] render() // The file uploader should receive the existing file as an array - expect(screen.getByTestId('file-value')).toHaveTextContent(JSON.stringify([existingFile])) + // The file uploader should receive the existing file as an array + expect(screen.getByTestId('file-value'))!.toHaveTextContent(JSON.stringify([existingFile])) }) it('should render file uploader for multi files', () => { @@ -876,7 +883,7 @@ describe('AppInputsForm', () => { }, ] render() - expect(screen.getByText('Attachments')).toBeInTheDocument() + expect(screen.getByText('Attachments'))!.toBeInTheDocument() }) it('should show optional label for non-required fields', () => { @@ -884,7 +891,7 @@ describe('AppInputsForm', () => { { type: InputVarType.textInput, label: 'Name', variable: 'name', required: false }, ] render() - expect(screen.getByText('workflow.panel.optional')).toBeInTheDocument() + expect(screen.getByText('workflow.panel.optional'))!.toBeInTheDocument() }) it('should not show optional label for required fields', () => { @@ -1026,7 +1033,7 @@ describe('AppInputsForm', () => { render() const input = screen.getByPlaceholderText('Name') - expect(input).toHaveValue('existing') + expect(input)!.toHaveValue('existing') }) it('should handle empty string value', () => { @@ -1036,7 +1043,7 @@ describe('AppInputsForm', () => { render() const input = screen.getByPlaceholderText('Name') - expect(input).toHaveValue('') + expect(input)!.toHaveValue('') }) it('should handle undefined variable value', () => { @@ -1046,7 +1053,7 @@ describe('AppInputsForm', () => { render() const input = screen.getByPlaceholderText('Name') - expect(input).toHaveValue('') + expect(input)!.toHaveValue('') }) it('should handle multiple form fields', () => { @@ -1057,9 +1064,9 @@ describe('AppInputsForm', () => { ] render() - expect(screen.getByText('Name')).toBeInTheDocument() - expect(screen.getByText('Age')).toBeInTheDocument() - expect(screen.getByText('Bio')).toBeInTheDocument() + expect(screen.getByText('Name'))!.toBeInTheDocument() + expect(screen.getByText('Age'))!.toBeInTheDocument() + expect(screen.getByText('Bio'))!.toBeInTheDocument() }) it('should handle unknown form type gracefully', () => { @@ -1068,7 +1075,7 @@ describe('AppInputsForm', () => { ] // Should not throw error, just not render the field render() - expect(screen.getByText('Unknown')).toBeInTheDocument() + expect(screen.getByText('Unknown'))!.toBeInTheDocument() }) }) }) @@ -1093,18 +1100,49 @@ describe('AppInputsPanel', () => { describe('Rendering', () => { it('should render without crashing', () => { renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should show no params message when form schema is empty', () => { renderWithQueryClient() - expect(screen.getByText('app.appSelector.noParams')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.noParams'))!.toBeInTheDocument() }) it('should show loading state when app is loading', () => { mockAppDetailLoading = true renderWithQueryClient() // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered + // Loading component should be rendered expect(screen.queryByText('app.appSelector.params')).not.toBeInTheDocument() }) @@ -1119,19 +1157,19 @@ describe('AppInputsPanel', () => { describe('Props', () => { it('should handle undefined value', () => { renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should handle different app modes', () => { const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should handle advanced chat mode', () => { const advancedChatApp = createMockApp({ mode: AppModeEnum.ADVANCED_CHAT }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) }) @@ -1147,7 +1185,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for number input', () => { @@ -1161,7 +1199,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for checkbox input', () => { @@ -1175,7 +1213,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for select input', () => { @@ -1189,7 +1227,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for file-list input', () => { @@ -1203,7 +1241,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for file input', () => { @@ -1217,7 +1255,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for json_object input', () => { @@ -1231,7 +1269,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for text-input (default)', () => { @@ -1245,7 +1283,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should filter external_data_tool items', () => { @@ -1260,7 +1298,7 @@ describe('AppInputsPanel', () => { }, }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) }) @@ -1283,7 +1321,7 @@ describe('AppInputsPanel', () => { } const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for workflow with singleFile variable', () => { @@ -1304,7 +1342,7 @@ describe('AppInputsPanel', () => { } const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should generate schema for workflow with regular variable', () => { @@ -1325,7 +1363,7 @@ describe('AppInputsPanel', () => { } const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) }) @@ -1344,7 +1382,7 @@ describe('AppInputsPanel', () => { }) const completionApp = createMockApp({ mode: AppModeEnum.COMPLETION }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should add image upload schema for WORKFLOW mode with file upload enabled', () => { @@ -1364,7 +1402,7 @@ describe('AppInputsPanel', () => { } const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) }) @@ -1372,7 +1410,7 @@ describe('AppInputsPanel', () => { it('should call onFormChange when form is updated', () => { const onFormChange = vi.fn() renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) it('should call onFormChange with updated values when text input changes', () => { @@ -1424,7 +1462,7 @@ describe('AppInputsPanel', () => { , ) - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) }) @@ -1432,7 +1470,7 @@ describe('AppInputsPanel', () => { it('should return empty schema when currentApp is null', () => { mockAppDetailData = null renderWithQueryClient() - expect(screen.getByText('app.appSelector.noParams')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.noParams'))!.toBeInTheDocument() }) it('should handle workflow without start node', () => { @@ -1442,7 +1480,7 @@ describe('AppInputsPanel', () => { } const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW }) renderWithQueryClient() - expect(screen.getByText('app.appSelector.params')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.params'))!.toBeInTheDocument() }) }) }) @@ -1477,12 +1515,12 @@ describe('AppSelector', () => { describe('Rendering', () => { it('should render without crashing', () => { renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should render trigger component', () => { renderWithQueryClient() - expect(screen.getByText('app.appSelector.placeholder')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.placeholder'))!.toBeInTheDocument() }) it('should show selected app info when value is provided', () => { @@ -1493,19 +1531,20 @@ describe('AppSelector', () => { />, ) // Should show the app trigger with app info - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // Should show the app trigger with app info + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) }) describe('Props', () => { it('should handle different placement values', () => { renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle different offset values', () => { renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle disabled state', () => { @@ -1513,12 +1552,13 @@ describe('AppSelector', () => { const trigger = screen.getByTestId('portal-trigger') fireEvent.click(trigger) // Portal should remain closed when disabled - expect(screen.getByTestId('portal-to-follow-elem')).toHaveAttribute('data-open', 'false') + // Portal should remain closed when disabled + expect(screen.getByTestId('portal-to-follow-elem'))!.toHaveAttribute('data-open', 'false') }) it('should handle scope prop', () => { renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle value with inputs', () => { @@ -1528,7 +1568,7 @@ describe('AppSelector', () => { value={{ app_id: 'app-1', inputs: { name: 'test' }, files: [] }} />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle value with files', () => { @@ -1538,7 +1578,7 @@ describe('AppSelector', () => { value={{ app_id: 'app-1', inputs: {}, files: [{ id: 'file-1' }] }} />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) }) @@ -1547,10 +1587,11 @@ describe('AppSelector', () => { renderWithQueryClient() const trigger = screen.getAllByTestId('portal-trigger')[0] - fireEvent.click(trigger) + fireEvent.click(trigger!) // The portal state should update synchronously - get the first one (outer portal) - expect(screen.getAllByTestId('portal-to-follow-elem')[0]).toHaveAttribute('data-open', 'true') + // The portal state should update synchronously - get the first one (outer portal) + expect(screen.getAllByTestId('portal-to-follow-elem')[0])!.toHaveAttribute('data-open', 'true') }) it('should not toggle isShow when disabled', () => { @@ -1559,7 +1600,7 @@ describe('AppSelector', () => { const trigger = screen.getByTestId('portal-trigger') fireEvent.click(trigger) - expect(screen.getByTestId('portal-to-follow-elem')).toHaveAttribute('data-open', 'false') + expect(screen.getByTestId('portal-to-follow-elem'))!.toHaveAttribute('data-open', 'false') }) it('should manage search text state', () => { @@ -1569,7 +1610,8 @@ describe('AppSelector', () => { fireEvent.click(trigger) // Portal content should be visible after click - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Portal content should be visible after click + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) it('should render correctly during load more setup', () => { @@ -1579,7 +1621,8 @@ describe('AppSelector', () => { renderWithQueryClient() // Trigger should be rendered - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + // Trigger should be rendered + expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument() }) }) @@ -1592,7 +1635,7 @@ describe('AppSelector', () => { // Open the portal fireEvent.click(screen.getByTestId('portal-trigger')) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) it('should call onSelect with correct value structure', () => { @@ -1606,7 +1649,8 @@ describe('AppSelector', () => { ) // The component should maintain the correct value structure - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // The component should maintain the correct value structure + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should clear inputs when selecting different app', () => { @@ -1620,7 +1664,8 @@ describe('AppSelector', () => { ) // Component renders with existing value - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // Component renders with existing value + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should preserve inputs when selecting same app', () => { @@ -1633,7 +1678,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) }) @@ -1647,7 +1692,7 @@ describe('AppSelector', () => { } renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should memoize currentAppInfo correctly', () => { @@ -1662,7 +1707,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should memoize formattedValue correctly', () => { @@ -1673,7 +1718,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should be wrapped with React.memo', () => { @@ -1690,7 +1735,7 @@ describe('AppSelector', () => { , ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) }) @@ -1698,7 +1743,7 @@ describe('AppSelector', () => { it('should handle load more when hasMore is true', async () => { mockHasNextPage = true renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should not trigger load more when already loading', async () => { @@ -1724,7 +1769,7 @@ describe('AppSelector', () => { vi.advanceTimersByTime(500) }) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should render load more area when hasMore is true', () => { @@ -1735,10 +1780,11 @@ describe('AppSelector', () => { renderWithQueryClient() // Open the portal - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Should render without errors - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Should render without errors + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) it('should handle fetchNextPage rejection gracefully in handleLoadMore', async () => { @@ -1748,7 +1794,8 @@ describe('AppSelector', () => { renderWithQueryClient() // Should not crash even if fetchNextPage rejects - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // Should not crash even if fetchNextPage rejects + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should call fetchNextPage when intersection observer triggers handleLoadMore', async () => { @@ -1759,11 +1806,11 @@ describe('AppSelector', () => { renderWithQueryClient() // Open the main portal - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Open the inner app picker portal const triggers = screen.getAllByTestId('portal-trigger') - fireEvent.click(triggers[1]) + fireEvent.click(triggers[1]!) // Simulate intersection to trigger handleLoadMore triggerIntersection([{ isIntersecting: true } as IntersectionObserverEntry]) @@ -1780,9 +1827,9 @@ describe('AppSelector', () => { renderWithQueryClient() // Open portals - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) const triggers = screen.getAllByTestId('portal-trigger') - fireEvent.click(triggers[1]) + fireEvent.click(triggers[1]!) // Trigger first intersection triggerIntersection([{ isIntersecting: true } as IntersectionObserverEntry]) @@ -1806,9 +1853,9 @@ describe('AppSelector', () => { renderWithQueryClient() // Open portals - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) const triggers = screen.getAllByTestId('portal-trigger') - fireEvent.click(triggers[1]) + fireEvent.click(triggers[1]!) // Trigger intersection triggerIntersection([{ isIntersecting: true } as IntersectionObserverEntry]) @@ -1825,9 +1872,9 @@ describe('AppSelector', () => { renderWithQueryClient() // Open portals - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) const triggers = screen.getAllByTestId('portal-trigger') - fireEvent.click(triggers[1]) + fireEvent.click(triggers[1]!) // Trigger intersection triggerIntersection([{ isIntersecting: true } as IntersectionObserverEntry]) @@ -1847,7 +1894,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle form change without image file', () => { @@ -1860,7 +1907,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should extract #image# from inputs and add to files array', () => { @@ -1874,7 +1921,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should preserve existing files when no #image# in inputs', () => { @@ -1887,7 +1934,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) }) @@ -1907,9 +1954,9 @@ describe('AppSelector', () => { ) // Open the main portal - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) it('should preserve inputs when selecting the same app', () => { @@ -1926,7 +1973,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle app selection with empty value', () => { @@ -1944,34 +1991,34 @@ describe('AppSelector', () => { ) // Open the main portal - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) }) describe('Edge Cases', () => { it('should handle undefined value', () => { renderWithQueryClient() - expect(screen.getByText('app.appSelector.placeholder')).toBeInTheDocument() + expect(screen.getByText('app.appSelector.placeholder'))!.toBeInTheDocument() }) it('should handle empty pages array', () => { mockAppListData = { pages: [] } renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle undefined data', () => { mockAppListData = undefined renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle loading state', () => { mockIsLoading = true renderWithQueryClient() - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle app not found in displayedApps', () => { @@ -1986,7 +2033,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle value with empty inputs and files', () => { @@ -1997,7 +2044,7 @@ describe('AppSelector', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) }) @@ -2009,7 +2056,8 @@ describe('AppSelector', () => { renderWithQueryClient() // Should not crash - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + // Should not crash + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) }) }) @@ -2043,10 +2091,11 @@ describe('AppSelector Integration', () => { renderWithQueryClient() // 1. Click trigger to open picker - get first trigger (outer portal) - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Get the first portal element (outer portal) - expect(screen.getAllByTestId('portal-to-follow-elem')[0]).toHaveAttribute('data-open', 'true') + // Get the first portal element (outer portal) + expect(screen.getAllByTestId('portal-to-follow-elem')[0])!.toHaveAttribute('data-open', 'true') }) it('should handle app change with input preservation logic', () => { @@ -2058,7 +2107,7 @@ describe('AppSelector Integration', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) }) @@ -2067,7 +2116,8 @@ describe('AppSelector Integration', () => { renderWithQueryClient() // AppTrigger should show placeholder when no app selected - expect(screen.getByText('app.appSelector.placeholder')).toBeInTheDocument() + // AppTrigger should show placeholder when no app selected + expect(screen.getByText('app.appSelector.placeholder'))!.toBeInTheDocument() }) it('should pass correct props to AppPicker', () => { @@ -2075,7 +2125,7 @@ describe('AppSelector Integration', () => { fireEvent.click(screen.getByTestId('portal-trigger')) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) }) @@ -2088,7 +2138,7 @@ describe('AppSelector Integration', () => { />, ) - expect(screen.getByTestId('portal-to-follow-elem')).toBeInTheDocument() + expect(screen.getByTestId('portal-to-follow-elem'))!.toBeInTheDocument() }) it('should handle search filtering through app list', () => { @@ -2096,7 +2146,7 @@ describe('AppSelector Integration', () => { fireEvent.click(screen.getByTestId('portal-trigger')) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) }) @@ -2115,13 +2165,13 @@ describe('AppSelector Integration', () => { ) // Open the main portal - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // The inner AppPicker portal is closed by default (isShowChooseApp = false) // We need to click on the inner trigger to open it const innerTriggers = screen.getAllByTestId('portal-trigger') // The second trigger is the inner AppPicker trigger - fireEvent.click(innerTriggers[1]) + fireEvent.click(innerTriggers[1]!) // Now the inner portal should be open and show the app list // Find and click on app-2 @@ -2150,16 +2200,16 @@ describe('AppSelector Integration', () => { ) // Open the main portal - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Click on the inner trigger to open app picker const innerTriggers = screen.getAllByTestId('portal-trigger') - fireEvent.click(innerTriggers[1]) + fireEvent.click(innerTriggers[1]!) // Click on the same app - need to get the one in the app list, not the trigger const appItems = screen.getAllByText('App 1') // The last one should be in the dropdown list - fireEvent.click(appItems[appItems.length - 1]) + fireEvent.click(appItems[appItems.length - 1]!) // onSelect should be called with preserved inputs since it's the same app expect(onSelect).toHaveBeenCalledWith({ @@ -2183,15 +2233,15 @@ describe('AppSelector Integration', () => { ) // Open the main portal - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Click on inner trigger to open app picker const innerTriggers = screen.getAllByTestId('portal-trigger') - fireEvent.click(innerTriggers[1]) + fireEvent.click(innerTriggers[1]!) // Click on an app from the dropdown const app1Elements = screen.getAllByText('App 1') - fireEvent.click(app1Elements[app1Elements.length - 1]) + fireEvent.click(app1Elements[app1Elements.length - 1]!) // onSelect should be called with new app and empty inputs/files expect(onSelect).toHaveBeenCalledWith({ @@ -2211,9 +2261,9 @@ describe('AppSelector Integration', () => { renderWithQueryClient() // Open the portal to render the app picker - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) it('should stay stable after fetchNextPage completes', async () => { @@ -2223,9 +2273,9 @@ describe('AppSelector Integration', () => { renderWithQueryClient() - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) it('should not call fetchNextPage when conditions prevent it', () => { @@ -2234,7 +2284,7 @@ describe('AppSelector Integration', () => { renderWithQueryClient() - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // fetchNextPage should not be called expect(mockFetchNextPage).not.toHaveBeenCalled() @@ -2256,7 +2306,7 @@ describe('AppSelector Integration', () => { ) // Open portal - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // formattedValue should include #image# from files expect(screen.getAllByTestId('portal-content').length).toBeGreaterThan(0) @@ -2275,7 +2325,7 @@ describe('AppSelector Integration', () => { />, ) - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) expect(screen.getAllByTestId('portal-content').length).toBeGreaterThan(0) }) @@ -2293,7 +2343,7 @@ describe('AppSelector Integration', () => { />, ) - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) expect(screen.getAllByTestId('portal-content').length).toBeGreaterThan(0) }) @@ -2324,12 +2374,12 @@ describe('AppSelector Integration', () => { ) // Open portal to render AppInputsPanel - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Find and interact with the form input (may not exist if schema is empty) const formInputs = screen.queryAllByPlaceholderText('FormInputField') if (formInputs.length > 0) { - fireEvent.change(formInputs[0], { target: { value: 'test value' } }) + fireEvent.change(formInputs[0]!, { target: { value: 'test value' } }) // handleFormChange in index.tsx should have been called expect(onSelect).toHaveBeenCalledWith({ @@ -2376,12 +2426,12 @@ describe('AppSelector Integration', () => { />, ) - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Find file uploader and trigger upload - the #image# field will be extracted const uploadBtns = screen.queryAllByTestId('upload-file-btn') if (uploadBtns.length > 0) { - fireEvent.click(uploadBtns[0]) + fireEvent.click(uploadBtns[0]!) // handleFormChange should extract #image# and convert to files expect(onSelect).toHaveBeenCalled() } @@ -2414,12 +2464,12 @@ describe('AppSelector Integration', () => { />, ) - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Find form input (may not exist if schema is empty) const inputs = screen.queryAllByPlaceholderText('PreserveField') if (inputs.length > 0) { - fireEvent.change(inputs[0], { target: { value: 'updated name' } }) + fireEvent.change(inputs[0]!, { target: { value: 'updated name' } }) // onSelect should be called preserving existing files (no #image# in inputs) expect(onSelect).toHaveBeenCalledWith({ @@ -2465,7 +2515,7 @@ describe('AppSelector Integration', () => { />, ) - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) // Try to find and click the upload button which triggers #image# form change const uploadBtn = screen.queryByTestId('upload-file-btn') @@ -2499,11 +2549,11 @@ describe('AppSelector Integration', () => { />, ) - fireEvent.click(screen.getAllByTestId('portal-trigger')[0]) + fireEvent.click(screen.getAllByTestId('portal-trigger')[0]!) const inputs = screen.queryAllByPlaceholderText('SimpleInput') if (inputs.length > 0) { - fireEvent.change(inputs[0], { target: { value: 'changed' } }) + fireEvent.change(inputs[0]!, { target: { value: 'changed' } }) // handleFormChange should preserve existing files when no #image# in inputs expect(onSelect).toHaveBeenCalledWith({ app_id: 'app-1', diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx index c4cb4f4da8..41140ac63b 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx @@ -78,7 +78,7 @@ const AppPicker: FC = ({ const handleIntersection = useCallback((entries: IntersectionObserverEntry[]) => { const target = entries[0] - if (!target.isIntersecting || loadingRef.current || !hasMore || isLoading) + if (!target!.isIntersecting || loadingRef.current || !hasMore || isLoading) return loadingRef.current = true @@ -188,7 +188,7 @@ const AppPicker: FC = ({ {apps.map(app => (
onSelect(app)} > = ({ background={app.icon_background} imageUrl={app.icon_url} /> -
+
{app.name} ( @@ -207,7 +207,7 @@ const AppPicker: FC = ({ )
-
{getAppType(app)}
+
{getAppType(app)}
))}
diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx index abd47b4592..97e144af6f 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx @@ -161,7 +161,7 @@ const AppSelector: FC = ({
-
{t('appSelector.label', { ns: 'app' })}
+
{t('appSelector.label', { ns: 'app' })}
{ const { container } = render() // Assert - expect(container).toBeInTheDocument() + // Assert + expect(container)!.toBeInTheDocument() }) it('should render language label', () => { @@ -172,7 +173,8 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByText('appDebug.voice.voiceSettings.language')).toBeInTheDocument() + // Assert + expect(screen.getByText('appDebug.voice.voiceSettings.language'))!.toBeInTheDocument() }) it('should render voice label', () => { @@ -183,7 +185,8 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByText('appDebug.voice.voiceSettings.voice')).toBeInTheDocument() + // Assert + expect(screen.getByText('appDebug.voice.voiceSettings.voice'))!.toBeInTheDocument() }) it('should render two Select components', () => { @@ -207,7 +210,7 @@ describe('TTSParamsPanel', () => { // Assert const values = screen.getAllByTestId('selected-value') - expect(values[0]).toHaveTextContent('zh-Hans') + expect(values[0])!.toHaveTextContent('zh-Hans') }) it('should render voice select with correct value', () => { @@ -219,7 +222,7 @@ describe('TTSParamsPanel', () => { // Assert const values = screen.getAllByTestId('selected-value') - expect(values[1]).toHaveTextContent('echo') + expect(values[1])!.toHaveTextContent('echo') }) it('should only show supported languages in language select', () => { @@ -230,9 +233,10 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByTestId('select-item-en-US')).toBeInTheDocument() - expect(screen.getByTestId('select-item-zh-Hans')).toBeInTheDocument() - expect(screen.getByTestId('select-item-ja-JP')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('select-item-en-US'))!.toBeInTheDocument() + expect(screen.getByTestId('select-item-zh-Hans'))!.toBeInTheDocument() + expect(screen.getByTestId('select-item-ja-JP'))!.toBeInTheDocument() expect(screen.queryByTestId('select-item-unsupported-lang')).not.toBeInTheDocument() }) @@ -244,9 +248,10 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByTestId('select-item-alloy')).toBeInTheDocument() - expect(screen.getByTestId('select-item-echo')).toBeInTheDocument() - expect(screen.getByTestId('select-item-fable')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('select-item-alloy'))!.toBeInTheDocument() + expect(screen.getByTestId('select-item-echo'))!.toBeInTheDocument() + expect(screen.getByTestId('select-item-fable'))!.toBeInTheDocument() }) }) @@ -260,8 +265,9 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByTestId('tts-language-select-trigger')).toHaveAttribute('data-class', 'w-full') - expect(screen.getByTestId('tts-voice-select-trigger')).toHaveAttribute('data-class', 'w-full') + // Assert + expect(screen.getByTestId('tts-language-select-trigger'))!.toHaveAttribute('data-class', 'w-full') + expect(screen.getByTestId('tts-voice-select-trigger'))!.toHaveAttribute('data-class', 'w-full') }) it('should apply popup className to SelectContent', () => { @@ -273,8 +279,8 @@ describe('TTSParamsPanel', () => { // Assert const contents = screen.getAllByTestId('select-content') - expect(contents[0]).toHaveAttribute('data-popup-class', 'w-[354px]') - expect(contents[1]).toHaveAttribute('data-popup-class', 'w-[354px]') + expect(contents[0])!.toHaveAttribute('data-popup-class', 'w-[354px]') + expect(contents[1])!.toHaveAttribute('data-popup-class', 'w-[354px]') }) }) @@ -396,6 +402,37 @@ describe('TTSParamsPanel', () => { // Act render() + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered + // Assert - no voice items should be rendered // Assert - no voice items should be rendered expect(screen.queryByTestId('select-item-alloy')).not.toBeInTheDocument() expect(screen.queryByTestId('select-item-echo')).not.toBeInTheDocument() @@ -413,6 +450,37 @@ describe('TTSParamsPanel', () => { // Act render() + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert // Assert expect(screen.queryByTestId('select-item-alloy')).not.toBeInTheDocument() }) @@ -430,8 +498,9 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByTestId('select-item-voice-1')).toBeInTheDocument() - expect(screen.getByTestId('select-item-voice-2')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('select-item-voice-1'))!.toBeInTheDocument() + expect(screen.getByTestId('select-item-voice-2'))!.toBeInTheDocument() }) it('should handle currentModel with empty voices array', () => { @@ -444,7 +513,7 @@ describe('TTSParamsPanel', () => { render() // Assert - no voice items (except language items) - expect(screen.getAllByTestId('select-content')[1].children).toHaveLength(0) + expect(screen.getAllByTestId('select-content')[1]!.children).toHaveLength(0) expect(screen.queryByTestId('select-item-alloy')).not.toBeInTheDocument() }) @@ -460,7 +529,8 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByTestId('select-item-single-voice')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('select-item-single-voice'))!.toBeInTheDocument() }) }) @@ -475,7 +545,7 @@ describe('TTSParamsPanel', () => { // Assert const values = screen.getAllByTestId('selected-value') - expect(values[0]).toHaveTextContent('') + expect(values[0])!.toHaveTextContent('') }) it('should handle empty voice value', () => { @@ -487,7 +557,7 @@ describe('TTSParamsPanel', () => { // Assert const values = screen.getAllByTestId('selected-value') - expect(values[1]).toHaveTextContent('') + expect(values[1])!.toHaveTextContent('') }) it('should handle many voices', () => { @@ -504,8 +574,9 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByTestId('select-item-voice-0')).toBeInTheDocument() - expect(screen.getByTestId('select-item-voice-19')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('select-item-voice-0'))!.toBeInTheDocument() + expect(screen.getByTestId('select-item-voice-19'))!.toBeInTheDocument() }) it('should handle voice with special characters in mode', () => { @@ -520,7 +591,8 @@ describe('TTSParamsPanel', () => { render() // Assert - expect(screen.getByTestId('select-item-voice-with_special.chars')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('select-item-voice-with_special.chars'))!.toBeInTheDocument() }) it('should handle onChange not being called multiple times', () => { @@ -546,13 +618,13 @@ describe('TTSParamsPanel', () => { // Act const { rerender } = render() const values = screen.getAllByTestId('selected-value') - expect(values[0]).toHaveTextContent('en-US') + expect(values[0])!.toHaveTextContent('en-US') rerender() // Assert const updatedValues = screen.getAllByTestId('selected-value') - expect(updatedValues[0]).toHaveTextContent('zh-Hans') + expect(updatedValues[0])!.toHaveTextContent('zh-Hans') }) it('should update when voice prop changes', () => { @@ -562,13 +634,13 @@ describe('TTSParamsPanel', () => { // Act const { rerender } = render() const values = screen.getAllByTestId('selected-value') - expect(values[1]).toHaveTextContent('alloy') + expect(values[1])!.toHaveTextContent('alloy') rerender() // Assert const updatedValues = screen.getAllByTestId('selected-value') - expect(updatedValues[1]).toHaveTextContent('echo') + expect(updatedValues[1])!.toHaveTextContent('echo') }) it('should update voice list when currentModel changes', () => { @@ -580,7 +652,7 @@ describe('TTSParamsPanel', () => { // Act const { rerender } = render() - expect(screen.getByTestId('select-item-alloy')).toBeInTheDocument() + expect(screen.getByTestId('select-item-alloy'))!.toBeInTheDocument() expect(screen.queryByTestId('select-item-nova')).not.toBeInTheDocument() const newModel = createCurrentModel([ @@ -590,8 +662,9 @@ describe('TTSParamsPanel', () => { rerender() // Assert - expect(screen.getByTestId('select-item-alloy')).toBeInTheDocument() - expect(screen.getByTestId('select-item-nova')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('select-item-alloy'))!.toBeInTheDocument() + expect(screen.getByTestId('select-item-nova'))!.toBeInTheDocument() }) it('should handle currentModel becoming null', () => { @@ -600,10 +673,41 @@ describe('TTSParamsPanel', () => { // Act const { rerender } = render() - expect(screen.getByTestId('select-item-alloy')).toBeInTheDocument() + expect(screen.getByTestId('select-item-alloy'))!.toBeInTheDocument() rerender() + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert // Assert expect(screen.queryByTestId('select-item-alloy')).not.toBeInTheDocument() }) @@ -636,7 +740,7 @@ describe('TTSParamsPanel', () => { // Assert const languageLabel = screen.getByText('appDebug.voice.voiceSettings.language') - expect(languageLabel).toHaveClass('system-sm-semibold') + expect(languageLabel)!.toHaveClass('system-sm-semibold') }) it('should have proper label structure for voice select', () => { @@ -648,7 +752,7 @@ describe('TTSParamsPanel', () => { // Assert const voiceLabel = screen.getByText('appDebug.voice.voiceSettings.voice') - expect(voiceLabel).toHaveClass('system-sm-semibold') + expect(voiceLabel)!.toHaveClass('system-sm-semibold') }) }) }) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/index.spec.tsx index d41dfaa7d0..19a2d5a9f1 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/index.spec.tsx @@ -100,8 +100,8 @@ describe('SubscriptionList', () => { it('should render list view by default', () => { render() - expect(screen.getByText(/pluginTrigger\.subscription\.listNum/)).toBeInTheDocument() - expect(screen.getByText('Subscription One')).toBeInTheDocument() + expect(screen.getByText(/pluginTrigger\.subscription\.listNum/))!.toBeInTheDocument() + expect(screen.getByText('Subscription One'))!.toBeInTheDocument() }) it('should render loading state when subscriptions are loading', () => { @@ -112,7 +112,7 @@ describe('SubscriptionList', () => { render() - expect(screen.getByRole('status')).toBeInTheDocument() + expect(screen.getByRole('status'))!.toBeInTheDocument() expect(screen.queryByText('Subscription One')).not.toBeInTheDocument() }) @@ -121,7 +121,7 @@ describe('SubscriptionList', () => { render() - expect(screen.getByText('Subscription One')).toBeInTheDocument() + expect(screen.getByText('Subscription One'))!.toBeInTheDocument() }) it('should render without list entries when subscriptions are empty', () => { @@ -141,7 +141,7 @@ describe('SubscriptionList', () => { it('should render selector view when mode is selector', () => { render() - expect(screen.getByText('Subscription One')).toBeInTheDocument() + expect(screen.getByText('Subscription One'))!.toBeInTheDocument() }) it('should visually distinguish selected subscription from unselected', () => { @@ -182,7 +182,7 @@ describe('SubscriptionList', () => { fireEvent.click(screen.getByRole('button', { name: 'Subscription One' })) expect(onSelect).toHaveBeenCalledTimes(1) - const [selectedSubscription, callback] = onSelect.mock.calls[0] + const [selectedSubscription, callback] = (onSelect.mock.calls[0] ?? []) as [any, any] expect(selectedSubscription).toMatchObject({ id: 'sub-1', name: 'Subscription One' }) expect(typeof callback).toBe('function') @@ -212,7 +212,7 @@ describe('SubscriptionList', () => { fireEvent.click(deleteButton) expect(onSelect).not.toHaveBeenCalled() - expect(screen.getByText(/pluginTrigger\.subscription\.list\.item\.actions\.deleteConfirm\.title/)).toBeInTheDocument() + expect(screen.getByText(/pluginTrigger\.subscription\.list\.item\.actions\.deleteConfirm\.title/))!.toBeInTheDocument() }) }) @@ -222,7 +222,7 @@ describe('SubscriptionList', () => { render() - expect(await screen.findByText('Something went wrong')).toBeInTheDocument() + expect(await screen.findByText('Something went wrong'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx index 8a3642d043..4e6be7d81c 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx @@ -137,7 +137,7 @@ const ReasoningConfigForm: React.FC = ({ asChild={false} /> )) - const varInput = value[variable].value + const varInput = value[variable]!.value const { isString, isNumber, @@ -179,7 +179,7 @@ const ReasoningConfigForm: React.FC = ({ >
showSchema(input_schema as SchemaRoot, fieldTitle)} + onClick={() => showSchema(input_schema as SchemaRoot, fieldTitle!)} >
diff --git a/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx b/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx index 82c4d69d1b..0b0d9c7fc8 100644 --- a/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx +++ b/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx @@ -253,10 +253,11 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert - expect(screen.getByText('plugin.action.delete')).toBeInTheDocument() + // Assert + expect(screen.getByText('plugin.action.delete'))!.toBeInTheDocument() }) it('should display plugin name in delete confirm content', () => { @@ -270,10 +271,11 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert - expect(screen.getByText('my-awesome-plugin')).toBeInTheDocument() + // Assert + expect(screen.getByText('my-awesome-plugin'))!.toBeInTheDocument() }) it('should hide confirm modal when cancel is clicked', () => { @@ -286,8 +288,8 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) - expect(screen.getByText('plugin.action.delete')).toBeInTheDocument() + fireEvent.click(getActionButtons()[0]!) + expect(screen.getByText('plugin.action.delete'))!.toBeInTheDocument() fireEvent.click(getDeleteCancelButton()) @@ -308,7 +310,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) // Assert @@ -330,7 +332,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) // Assert @@ -352,7 +354,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) // Assert @@ -374,7 +376,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) // Assert @@ -401,12 +403,12 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) // Assert - Loading state await waitFor(() => { - expect(getDeleteConfirmButton()).toBeDisabled() + expect(getDeleteConfirmButton())!.toBeDisabled() }) // Resolve and check modal closes @@ -434,13 +436,14 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert - expect(screen.getByTestId('plugin-info-modal')).toBeInTheDocument() - expect(screen.getByTestId('plugin-info-modal')).toHaveAttribute('data-repo', 'owner/repo-name') - expect(screen.getByTestId('plugin-info-modal')).toHaveAttribute('data-release', '2.0.0') - expect(screen.getByTestId('plugin-info-modal')).toHaveAttribute('data-package', 'my-package.difypkg') + // Assert + expect(screen.getByTestId('plugin-info-modal'))!.toBeInTheDocument() + expect(screen.getByTestId('plugin-info-modal'))!.toHaveAttribute('data-repo', 'owner/repo-name') + expect(screen.getByTestId('plugin-info-modal'))!.toHaveAttribute('data-release', '2.0.0') + expect(screen.getByTestId('plugin-info-modal'))!.toHaveAttribute('data-package', 'my-package.difypkg') }) it('should hide plugin info modal when close is clicked', () => { @@ -453,11 +456,42 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) - expect(screen.getByTestId('plugin-info-modal')).toBeInTheDocument() + fireEvent.click(getActionButtons()[0]!) + expect(screen.getByTestId('plugin-info-modal'))!.toBeInTheDocument() fireEvent.click(screen.getByTestId('close-plugin-info')) + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert // Assert expect(screen.queryByTestId('plugin-info-modal')).not.toBeInTheDocument() }) @@ -481,7 +515,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert await waitFor(() => { @@ -507,7 +541,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert await waitFor(() => { @@ -526,7 +560,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert await waitFor(() => { @@ -550,7 +584,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert - toast is called with the translated payload await waitFor(() => { @@ -581,7 +615,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert await waitFor(() => { @@ -621,7 +655,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Wait for modal to be called await waitFor(() => { @@ -629,7 +663,7 @@ describe('Action Component', () => { }) // Invoke the callback - const call = mockSetShowUpdatePluginModal.mock.calls[0][0] + const call = mockSetShowUpdatePluginModal.mock.calls[0]![0] call.onSaveCallback() // Assert @@ -653,7 +687,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert await waitFor(() => { @@ -678,7 +712,7 @@ describe('Action Component', () => { // Act - First render and delete const { rerender } = render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) await waitFor(() => { @@ -688,7 +722,7 @@ describe('Action Component', () => { // Re-render with same props mockUninstallPlugin.mockClear() rerender() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) await waitFor(() => { @@ -714,7 +748,7 @@ describe('Action Component', () => { // Act const { rerender } = render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) await waitFor(() => { @@ -723,7 +757,7 @@ describe('Action Component', () => { mockUninstallPlugin.mockClear() rerender() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) await waitFor(() => { @@ -751,7 +785,7 @@ describe('Action Component', () => { // Act const { rerender } = render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) await waitFor(() => { @@ -760,7 +794,7 @@ describe('Action Component', () => { expect(onDelete2).not.toHaveBeenCalled() rerender() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) await waitFor(() => { @@ -802,7 +836,7 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert - Should use author and pluginName as fallback await waitFor(() => { @@ -826,11 +860,12 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) fireEvent.click(getDeleteConfirmButton()) // The confirm button should be disabled during deletion - expect(getDeleteConfirmButton()).toBeDisabled() + // The confirm button should be disabled during deletion + expect(getDeleteConfirmButton())!.toBeDisabled() // Resolve the deletion resolveFirst!({ success: true }) @@ -851,10 +886,11 @@ describe('Action Component', () => { // Act render() - fireEvent.click(getActionButtons()[0]) + fireEvent.click(getActionButtons()[0]!) // Assert - expect(screen.getByText('plugin-with-special@chars#123')).toBeInTheDocument() + // Assert + expect(screen.getByText('plugin-with-special@chars#123'))!.toBeInTheDocument() }) }) diff --git a/web/app/components/plugins/plugin-page/plugin-tasks/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-page/plugin-tasks/__tests__/index.spec.tsx index 8c33894929..ab3382ff75 100644 --- a/web/app/components/plugins/plugin-page/plugin-tasks/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-page/plugin-tasks/__tests__/index.spec.tsx @@ -83,8 +83,8 @@ describe('usePluginTaskStatus Hook', () => { render() - expect(screen.getByTestId('running-count')).toHaveTextContent('1') - expect(screen.getByTestId('running-id')).toHaveTextContent(runningPlugin.plugin_unique_identifier) + expect(screen.getByTestId('running-count'))!.toHaveTextContent('1') + expect(screen.getByTestId('running-id'))!.toHaveTextContent(runningPlugin.plugin_unique_identifier) }) it('should categorize success plugins correctly', () => { @@ -103,8 +103,8 @@ describe('usePluginTaskStatus Hook', () => { render() - expect(screen.getByTestId('success-count')).toHaveTextContent('1') - expect(screen.getByTestId('success-id')).toHaveTextContent(successPlugin.plugin_unique_identifier) + expect(screen.getByTestId('success-count'))!.toHaveTextContent('1') + expect(screen.getByTestId('success-id'))!.toHaveTextContent(successPlugin.plugin_unique_identifier) }) it('should categorize error plugins correctly', () => { @@ -123,8 +123,8 @@ describe('usePluginTaskStatus Hook', () => { render() - expect(screen.getByTestId('error-count')).toHaveTextContent('1') - expect(screen.getByTestId('error-id')).toHaveTextContent(errorPlugin.plugin_unique_identifier) + expect(screen.getByTestId('error-count'))!.toHaveTextContent('1') + expect(screen.getByTestId('error-id'))!.toHaveTextContent(errorPlugin.plugin_unique_identifier) }) it('should categorize mixed plugins correctly', () => { @@ -149,10 +149,10 @@ describe('usePluginTaskStatus Hook', () => { render() - expect(screen.getByTestId('running')).toHaveTextContent('1') - expect(screen.getByTestId('success')).toHaveTextContent('1') - expect(screen.getByTestId('error')).toHaveTextContent('1') - expect(screen.getByTestId('total')).toHaveTextContent('3') + expect(screen.getByTestId('running'))!.toHaveTextContent('1') + expect(screen.getByTestId('success'))!.toHaveTextContent('1') + expect(screen.getByTestId('error'))!.toHaveTextContent('1') + expect(screen.getByTestId('total'))!.toHaveTextContent('3') }) }) @@ -175,11 +175,11 @@ describe('usePluginTaskStatus Hook', () => { render() - expect(screen.getByTestId('isInstalling')).toHaveTextContent('true') - expect(screen.getByTestId('isInstallingWithSuccess')).toHaveTextContent('false') - expect(screen.getByTestId('isInstallingWithError')).toHaveTextContent('false') - expect(screen.getByTestId('isSuccess')).toHaveTextContent('false') - expect(screen.getByTestId('isFailed')).toHaveTextContent('false') + expect(screen.getByTestId('isInstalling'))!.toHaveTextContent('true') + expect(screen.getByTestId('isInstallingWithSuccess'))!.toHaveTextContent('false') + expect(screen.getByTestId('isInstallingWithError'))!.toHaveTextContent('false') + expect(screen.getByTestId('isSuccess'))!.toHaveTextContent('false') + expect(screen.getByTestId('isFailed'))!.toHaveTextContent('false') }) it('should set isInstallingWithSuccess when running and success plugins exist', () => { @@ -194,7 +194,7 @@ describe('usePluginTaskStatus Hook', () => { } render() - expect(screen.getByTestId('flag')).toHaveTextContent('true') + expect(screen.getByTestId('flag'))!.toHaveTextContent('true') }) it('should set isInstallingWithError when running and error plugins exist', () => { @@ -209,7 +209,7 @@ describe('usePluginTaskStatus Hook', () => { } render() - expect(screen.getByTestId('flag')).toHaveTextContent('true') + expect(screen.getByTestId('flag'))!.toHaveTextContent('true') }) it('should set isSuccess when all plugins succeeded', () => { @@ -224,7 +224,7 @@ describe('usePluginTaskStatus Hook', () => { } render() - expect(screen.getByTestId('flag')).toHaveTextContent('true') + expect(screen.getByTestId('flag'))!.toHaveTextContent('true') }) it('should set isFailed when no running plugins and some failed', () => { @@ -239,7 +239,7 @@ describe('usePluginTaskStatus Hook', () => { } render() - expect(screen.getByTestId('flag')).toHaveTextContent('true') + expect(screen.getByTestId('flag'))!.toHaveTextContent('true') }) }) @@ -296,12 +296,12 @@ describe('TaskStatusIndicator Component', () => { describe('Rendering', () => { it('should render without crashing', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should render with correct id', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) }) @@ -309,17 +309,18 @@ describe('TaskStatusIndicator Component', () => { it('should show downloading icon when installing', () => { render() // DownloadingIcon is rendered when isInstalling is true - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + // DownloadingIcon is rendered when isInstalling is true + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show downloading icon when installing with error', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show install icon when not installing', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) }) @@ -333,7 +334,7 @@ describe('TaskStatusIndicator Component', () => { totalPluginsLength={3} />, ) - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show progress circle when installing with success', () => { @@ -345,7 +346,7 @@ describe('TaskStatusIndicator Component', () => { totalPluginsLength={3} />, ) - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show error progress circle when installing with error', () => { @@ -357,7 +358,7 @@ describe('TaskStatusIndicator Component', () => { totalPluginsLength={3} />, ) - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show success icon when all completed successfully', () => { @@ -370,12 +371,12 @@ describe('TaskStatusIndicator Component', () => { totalPluginsLength={3} />, ) - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show error icon when failed', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) }) @@ -383,19 +384,19 @@ describe('TaskStatusIndicator Component', () => { it('should apply error styles when installing with error', () => { render() const trigger = document.getElementById('plugin-task-trigger') - expect(trigger).toHaveClass('bg-state-destructive-hover') + expect(trigger)!.toHaveClass('bg-state-destructive-hover') }) it('should apply error styles when failed', () => { render() const trigger = document.getElementById('plugin-task-trigger') - expect(trigger).toHaveClass('bg-state-destructive-hover') + expect(trigger)!.toHaveClass('bg-state-destructive-hover') }) it('should apply cursor-pointer when clickable', () => { render() const trigger = document.getElementById('plugin-task-trigger') - expect(trigger).toHaveClass('cursor-pointer') + expect(trigger)!.toHaveClass('cursor-pointer') }) }) @@ -429,7 +430,7 @@ describe('PluginTaskList Component', () => { describe('Rendering', () => { it('should render without crashing with empty lists', () => { render() - expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + expect(document.querySelector('.w-\\[360px\\]'))!.toBeInTheDocument() }) it('should render running plugins section when plugins exist', () => { @@ -439,7 +440,8 @@ describe('PluginTaskList Component', () => { // Translation key is returned as text in tests, multiple matches expected (title + status) expect(screen.getAllByText(/task\.installing/i).length).toBeGreaterThan(0) // Verify section container is rendered - expect(document.querySelector('.max-h-\\[300px\\]')).toBeInTheDocument() + // Verify section container is rendered + expect(document.querySelector('.max-h-\\[300px\\]'))!.toBeInTheDocument() }) it('should render success plugins section when plugins exist', () => { @@ -454,7 +456,7 @@ describe('PluginTaskList Component', () => { const errorPlugins = [createMockPlugin({ status: TaskStatus.failed, message: 'Error occurred' })] render() - expect(screen.getByText('Error occurred')).toBeInTheDocument() + expect(screen.getByText('Error occurred'))!.toBeInTheDocument() }) it('should render all sections when all types exist', () => { @@ -541,7 +543,7 @@ describe('PluginTaskList Component', () => { render() - expect(screen.getByText('My Test Plugin')).toBeInTheDocument() + expect(screen.getByText('My Test Plugin'))!.toBeInTheDocument() }) it('should display plugin message when available', () => { @@ -552,7 +554,7 @@ describe('PluginTaskList Component', () => { render() - expect(screen.getByText('Successfully installed!')).toBeInTheDocument() + expect(screen.getByText('Successfully installed!'))!.toBeInTheDocument() }) it('should display multiple plugins in each section', () => { @@ -563,8 +565,8 @@ describe('PluginTaskList Component', () => { render() - expect(screen.getByText('Plugin A')).toBeInTheDocument() - expect(screen.getByText('Plugin B')).toBeInTheDocument() + expect(screen.getByText('Plugin A'))!.toBeInTheDocument() + expect(screen.getByText('Plugin B'))!.toBeInTheDocument() // Count is rendered, verify multiple items are in list expect(document.querySelectorAll('.hover\\:bg-state-base-hover').length).toBe(2) }) @@ -593,7 +595,7 @@ describe('PluginTasks Component', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) }) @@ -604,7 +606,8 @@ describe('PluginTasks Component', () => { render() // The component renders with a tooltip, we verify it exists - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + // The component renders with a tooltip, we verify it exists + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show success tip when all succeeded', () => { @@ -612,7 +615,7 @@ describe('PluginTasks Component', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show error tip when some failed', () => { @@ -623,7 +626,7 @@ describe('PluginTasks Component', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) }) @@ -637,7 +640,8 @@ describe('PluginTasks Component', () => { fireEvent.click(document.getElementById('plugin-task-trigger')!) // The popover content should be visible (PluginTaskList) - expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + // The popover content should be visible (PluginTaskList) + expect(document.querySelector('.w-\\[360px\\]'))!.toBeInTheDocument() }) it('should not toggle when status does not allow', () => { @@ -647,7 +651,8 @@ describe('PluginTasks Component', () => { render() // Component should still render - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + // Component should still render + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) }) @@ -665,7 +670,7 @@ describe('PluginTasks Component', () => { // Wait for popover content to render await waitFor(() => { - expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + expect(document.querySelector('.w-\\[360px\\]'))!.toBeInTheDocument() }) // Find and click clear all button @@ -711,13 +716,13 @@ describe('PluginTasks Component', () => { fireEvent.click(document.getElementById('plugin-task-trigger')!) await waitFor(() => { - expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + expect(document.querySelector('.w-\\[360px\\]'))!.toBeInTheDocument() }) // Find and click the clear all button in error section const clearButtons = screen.getAllByRole('button') if (clearButtons.length > 0) - fireEvent.click(clearButtons[0]) + fireEvent.click(clearButtons[0]!) await waitFor(() => { expect(mockMutateAsync).toHaveBeenCalled() @@ -739,13 +744,13 @@ describe('PluginTasks Component', () => { fireEvent.click(document.getElementById('plugin-task-trigger')!) await waitFor(() => { - expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + expect(document.querySelector('.w-\\[360px\\]'))!.toBeInTheDocument() }) // Find and click individual clear button (usually the last one) const clearButtons = screen.getAllByRole('button') const individualClearButton = clearButtons[clearButtons.length - 1] - fireEvent.click(individualClearButton) + fireEvent.click(individualClearButton!) await waitFor(() => { expect(mockMutateAsync).toHaveBeenCalledWith({ @@ -770,7 +775,7 @@ describe('PluginTasks Component', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should handle many plugins', () => { @@ -783,7 +788,7 @@ describe('PluginTasks Component', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should handle plugins with empty labels', () => { @@ -795,7 +800,7 @@ describe('PluginTasks Component', () => { render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should handle plugins with long messages', () => { @@ -810,6 +815,30 @@ describe('PluginTasks Component', () => { // Open popover fireEvent.click(document.getElementById('plugin-task-trigger')!) + expect(document.querySelector('.w-\\[360px\\]'))!.toBeInTheDocument() + }) + + it('should open for installing-with-success state', () => { + setupMocks([ + createMockPlugin({ status: TaskStatus.running, plugin_unique_identifier: 'running-1' }), + createMockPlugin({ status: TaskStatus.success, plugin_unique_identifier: 'success-1' }), + ]) + + render() + fireEvent.click(document.getElementById('plugin-task-trigger')!) + + expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + }) + + it('should open for installing-with-error state', () => { + setupMocks([ + createMockPlugin({ status: TaskStatus.running, plugin_unique_identifier: 'running-1' }), + createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'failed-1' }), + ]) + + render() + fireEvent.click(document.getElementById('plugin-task-trigger')!) + expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() }) @@ -853,13 +882,13 @@ describe('PluginTasks Integration', () => { const { rerender } = render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() // Simulate completion by re-rendering with success setupMocks([createMockPlugin({ status: TaskStatus.success })]) rerender() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should show correct UI flow from installing to failure', async () => { @@ -868,13 +897,13 @@ describe('PluginTasks Integration', () => { const { rerender } = render() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() // Simulate failure by re-rendering with failed setupMocks([createMockPlugin({ status: TaskStatus.failed, message: 'Network error' })]) rerender() - expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + expect(document.getElementById('plugin-task-trigger'))!.toBeInTheDocument() }) it('should handle mixed status during installation', () => { diff --git a/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx b/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx index 5acd193a82..9a39b43bf6 100644 --- a/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx +++ b/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx @@ -33,6 +33,7 @@ const PluginTasks = () => { handleClearErrorPlugin, } = usePluginTaskStatus() const { getIconUrl } = useGetIcon() + const canOpenMenu = isFailed || isInstalling || isInstallingWithSuccess || isInstallingWithError || isSuccess // Generate tooltip text based on status const tip = useMemo(() => { @@ -85,11 +86,6 @@ const PluginTasks = () => { [clearPluginsAndClose], ) - const handleTriggerClick = useCallback(() => { - if (isFailed || isInstalling || isInstallingWithSuccess || isInstallingWithError || isSuccess) - setOpen(v => !v) - }, [isFailed, isInstalling, isInstallingWithSuccess, isInstallingWithError, isSuccess]) - // Hide when no plugin tasks if (totalPluginsLength === 0) return null @@ -102,7 +98,7 @@ const PluginTasks = () => { > } - onClick={handleTriggerClick} + disabled={!canOpenMenu} > { { render() // Assert - expect(screen.getByText('plugin.privilege.title')).toBeInTheDocument() + // Assert + expect(screen.getByText('plugin.privilege.title'))!.toBeInTheDocument() }) it('should render install permission section', () => { @@ -156,7 +157,8 @@ describe('reference-setting-modal', () => { render() // Assert - expect(screen.getByText('plugin.privilege.whoCanInstall')).toBeInTheDocument() + // Assert + expect(screen.getByText('plugin.privilege.whoCanInstall'))!.toBeInTheDocument() }) it('should render debug permission section', () => { @@ -164,7 +166,8 @@ describe('reference-setting-modal', () => { render() // Assert - expect(screen.getByText('plugin.privilege.whoCanDebug')).toBeInTheDocument() + // Assert + expect(screen.getByText('plugin.privilege.whoCanDebug'))!.toBeInTheDocument() }) it('should render all permission options for install', () => { @@ -180,8 +183,9 @@ describe('reference-setting-modal', () => { render() // Assert - expect(screen.getByText('common.operation.cancel')).toBeInTheDocument() - expect(screen.getByText('common.operation.save')).toBeInTheDocument() + // Assert + expect(screen.getByText('common.operation.cancel'))!.toBeInTheDocument() + expect(screen.getByText('common.operation.save'))!.toBeInTheDocument() }) it('should render AutoUpdateSetting when marketplace is enabled', () => { @@ -192,7 +196,8 @@ describe('reference-setting-modal', () => { render() // Assert - expect(screen.getByTestId('auto-update-setting')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('auto-update-setting'))!.toBeInTheDocument() }) it('should not render AutoUpdateSetting when marketplace is disabled', () => { @@ -202,6 +207,37 @@ describe('reference-setting-modal', () => { // Act render() + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert + // Assert // Assert expect(screen.queryByTestId('auto-update-setting')).not.toBeInTheDocument() }) @@ -211,7 +247,8 @@ describe('reference-setting-modal', () => { render() // Assert - expect(screen.getByTestId('modal-close')).toBeInTheDocument() + // Assert + expect(screen.getByTestId('modal-close'))!.toBeInTheDocument() }) }) @@ -230,11 +267,11 @@ describe('reference-setting-modal', () => { // Assert - admin option should be selected for install (first one) const adminOptions = screen.getAllByTestId('option-card-plugin.privilege.admins') - expect(adminOptions[0]).toHaveAttribute('aria-pressed', 'true') // Install permission + expect(adminOptions[0])!.toHaveAttribute('aria-pressed', 'true') // Install permission // Assert - noOne option should be selected for debug (second one) const noOneOptions = screen.getAllByTestId('option-card-plugin.privilege.noone') - expect(noOneOptions[1]).toHaveAttribute('aria-pressed', 'true') // Debug permission + expect(noOneOptions[1])!.toHaveAttribute('aria-pressed', 'true') // Debug permission }) it('should update tempPrivilege when permission option is clicked', () => { @@ -243,10 +280,11 @@ describe('reference-setting-modal', () => { // Act - click on "No One" for install permission const noOneOptions = screen.getAllByTestId('option-card-plugin.privilege.noone') - fireEvent.click(noOneOptions[0]) // First one is for install permission + fireEvent.click(noOneOptions[0]!) // First one is for install permission // Assert - the option should now be selected - expect(noOneOptions[0]).toHaveAttribute('aria-pressed', 'true') + // Assert - the option should now be selected + expect(noOneOptions[0])!.toHaveAttribute('aria-pressed', 'true') }) it('should initialize with payload auto_upgrade values', () => { @@ -261,7 +299,8 @@ describe('reference-setting-modal', () => { render() // Assert - expect(screen.getByTestId('auto-update-strategy')).toHaveTextContent('latest') + // Assert + expect(screen.getByTestId('auto-update-strategy'))!.toHaveTextContent('latest') }) it('should use default auto_upgrade when payload.auto_upgrade is undefined', () => { @@ -275,7 +314,8 @@ describe('reference-setting-modal', () => { render() // Assert - should use default value (disabled) - expect(screen.getByTestId('auto-update-strategy')).toHaveTextContent('disabled') + // Assert - should use default value (disabled) + expect(screen.getByTestId('auto-update-strategy'))!.toHaveTextContent('disabled') }) }) @@ -351,10 +391,11 @@ describe('reference-setting-modal', () => { // Click Everyone for install permission const everyoneOptions = screen.getAllByTestId('option-card-plugin.privilege.everyone') - fireEvent.click(everyoneOptions[0]) + fireEvent.click(everyoneOptions[0]!) // Assert - expect(everyoneOptions[0]).toHaveAttribute('aria-pressed', 'true') + // Assert + expect(everyoneOptions[0])!.toHaveAttribute('aria-pressed', 'true') }) it('should update debug permission when Admins Only option is clicked', () => { @@ -371,10 +412,11 @@ describe('reference-setting-modal', () => { // Click Admins Only for debug permission (second set of options) const adminOptions = screen.getAllByTestId('option-card-plugin.privilege.admins') - fireEvent.click(adminOptions[1]) // Second one is for debug permission + fireEvent.click(adminOptions[1]!) // Second one is for debug permission // Assert - expect(adminOptions[1]).toHaveAttribute('aria-pressed', 'true') + // Assert + expect(adminOptions[1])!.toHaveAttribute('aria-pressed', 'true') }) it('should update auto_upgrade config when changed in AutoUpdateSetting', async () => { @@ -410,7 +452,8 @@ describe('reference-setting-modal', () => { rerender() // Assert - component should render without issues - expect(screen.getByText('plugin.privilege.title')).toBeInTheDocument() + // Assert - component should render without issues + expect(screen.getByText('plugin.privilege.title'))!.toBeInTheDocument() }) it('handleSave should be memoized with useCallback', async () => { @@ -434,10 +477,11 @@ describe('reference-setting-modal', () => { // Act - click install permission option const everyoneOptions = screen.getAllByTestId('option-card-plugin.privilege.everyone') - fireEvent.click(everyoneOptions[0]) + fireEvent.click(everyoneOptions[0]!) // Assert - install permission should be updated - expect(everyoneOptions[0]).toHaveAttribute('aria-pressed', 'true') + // Assert - install permission should be updated + expect(everyoneOptions[0])!.toHaveAttribute('aria-pressed', 'true') }) }) @@ -456,7 +500,7 @@ describe('reference-setting-modal', () => { // Act & Assert - should not crash render() - expect(screen.getByText('plugin.privilege.title')).toBeInTheDocument() + expect(screen.getByText('plugin.privilege.title'))!.toBeInTheDocument() }) it('should handle undefined permission values', () => { @@ -471,7 +515,7 @@ describe('reference-setting-modal', () => { // Assert - should use default PermissionType.noOne const noOneOptions = screen.getAllByTestId('option-card-plugin.privilege.noone') - expect(noOneOptions[0]).toHaveAttribute('aria-pressed', 'true') + expect(noOneOptions[0])!.toHaveAttribute('aria-pressed', 'true') }) it('should handle missing install_permission', () => { @@ -487,7 +531,8 @@ describe('reference-setting-modal', () => { render() // Assert - should fall back to PermissionType.noOne - expect(screen.getByText('plugin.privilege.title')).toBeInTheDocument() + // Assert - should fall back to PermissionType.noOne + expect(screen.getByText('plugin.privilege.title'))!.toBeInTheDocument() }) it('should handle missing debug_permission', () => { @@ -503,7 +548,8 @@ describe('reference-setting-modal', () => { render() // Assert - should fall back to PermissionType.noOne - expect(screen.getByText('plugin.privilege.title')).toBeInTheDocument() + // Assert - should fall back to PermissionType.noOne + expect(screen.getByText('plugin.privilege.title'))!.toBeInTheDocument() }) it('should handle slow async onSave gracefully', async () => { @@ -555,7 +601,8 @@ describe('reference-setting-modal', () => { const { unmount } = render() // Assert - should render without crashing - expect(screen.getByText('plugin.privilege.title')).toBeInTheDocument() + // Assert - should render without crashing + expect(screen.getByText('plugin.privilege.title'))!.toBeInTheDocument() unmount() }) @@ -582,7 +629,8 @@ describe('reference-setting-modal', () => { const { unmount } = render() // Assert - expect(screen.getByTestId('auto-update-strategy')).toHaveTextContent(strategy) + // Assert + expect(screen.getByTestId('auto-update-strategy'))!.toHaveTextContent(strategy) unmount() }) @@ -608,7 +656,8 @@ describe('reference-setting-modal', () => { const { unmount } = render() // Assert - expect(screen.getByTestId('auto-update-mode')).toHaveTextContent(mode) + // Assert + expect(screen.getByTestId('auto-update-mode'))!.toHaveTextContent(mode) unmount() }) @@ -631,7 +680,7 @@ describe('reference-setting-modal', () => { // Change install permission to noOne const noOneOptions = screen.getAllByTestId('option-card-plugin.privilege.noone') - fireEvent.click(noOneOptions[0]) + fireEvent.click(noOneOptions[0]!) // Save fireEvent.click(screen.getByText('common.operation.save')) @@ -662,7 +711,7 @@ describe('reference-setting-modal', () => { // Change debug permission to noOne const noOneOptions = screen.getAllByTestId('option-card-plugin.privilege.noone') - fireEvent.click(noOneOptions[1]) // Second one is for debug + fireEvent.click(noOneOptions[1]!) // Second one is for debug // Save fireEvent.click(screen.getByText('common.operation.save')) @@ -691,7 +740,7 @@ describe('reference-setting-modal', () => { // Change install permission const everyoneOptions = screen.getAllByTestId('option-card-plugin.privilege.everyone') - fireEvent.click(everyoneOptions[0]) + fireEvent.click(everyoneOptions[0]!) // Save fireEvent.click(screen.getByText('common.operation.save')) @@ -717,7 +766,7 @@ describe('reference-setting-modal', () => { // Assert const modal = screen.getByTestId('modal') - expect(modal).toHaveClass('w-[620px]', 'max-w-[620px]', 'p-0!') + expect(modal)!.toHaveClass('w-[620px]', 'max-w-[620px]', 'p-0!') }) it('should pass isShow=true to Modal', () => { @@ -725,7 +774,8 @@ describe('reference-setting-modal', () => { render() // Assert - modal should be visible - expect(screen.getByTestId('modal')).toBeInTheDocument() + // Assert - modal should be visible + expect(screen.getByTestId('modal'))!.toBeInTheDocument() }) }) @@ -736,8 +786,8 @@ describe('reference-setting-modal', () => { // Assert - check order by getting all section labels const labels = screen.getAllByText(/plugin\.privilege\.whoCan/) - expect(labels[0]).toHaveTextContent('plugin.privilege.whoCanInstall') - expect(labels[1]).toHaveTextContent('plugin.privilege.whoCanDebug') + expect(labels[0])!.toHaveTextContent('plugin.privilege.whoCanInstall') + expect(labels[1])!.toHaveTextContent('plugin.privilege.whoCanDebug') }) it('should render three options per permission section', () => { @@ -762,8 +812,8 @@ describe('reference-setting-modal', () => { const cancelButton = screen.getByText('common.operation.cancel') const saveButton = screen.getByText('common.operation.save') - expect(cancelButton).toBeInTheDocument() - expect(saveButton).toBeInTheDocument() + expect(cancelButton)!.toBeInTheDocument() + expect(saveButton)!.toBeInTheDocument() }) }) }) @@ -797,11 +847,11 @@ describe('reference-setting-modal', () => { // Change install permission to Everyone const everyoneOptions = screen.getAllByTestId('option-card-plugin.privilege.everyone') - fireEvent.click(everyoneOptions[0]) + fireEvent.click(everyoneOptions[0]!) // Change debug permission to Admins Only const adminOptions = screen.getAllByTestId('option-card-plugin.privilege.admins') - fireEvent.click(adminOptions[1]) + fireEvent.click(adminOptions[1]!) // Change auto-update strategy fireEvent.click(screen.getByTestId('auto-update-change')) @@ -841,7 +891,7 @@ describe('reference-setting-modal', () => { // Make some changes const noOneOptions = screen.getAllByTestId('option-card-plugin.privilege.noone') - fireEvent.click(noOneOptions[0]) + fireEvent.click(noOneOptions[0]!) // Cancel fireEvent.click(screen.getByText('common.operation.cancel')) @@ -863,8 +913,9 @@ describe('reference-setting-modal', () => { render() // Assert - Labels are rendered correctly - expect(screen.getByText('plugin.privilege.whoCanInstall')).toBeInTheDocument() - expect(screen.getByText('plugin.privilege.whoCanDebug')).toBeInTheDocument() + // Assert - Labels are rendered correctly + expect(screen.getByText('plugin.privilege.whoCanInstall'))!.toBeInTheDocument() + expect(screen.getByText('plugin.privilege.whoCanDebug'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/index.tsx b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/index.tsx index de003e78ba..3854b6561a 100644 --- a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/index.tsx +++ b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/index.tsx @@ -28,7 +28,7 @@ const InputFieldForm = ({ initialData, supportFile = false, onCancel, onSubmit, if (!result.success) { const issues = result.error.issues const firstIssue = issues[0] - const errorMessage = `"${firstIssue.path.join('.')}" ${firstIssue.message}` + const errorMessage = `"${firstIssue!.path.join('.')}" ${firstIssue!.message}` toast.error(errorMessage) return errorMessage } diff --git a/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/options.tsx b/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/options.tsx index 49a1b9a284..7e6baa42b1 100644 --- a/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/options.tsx +++ b/web/app/components/rag-pipeline/components/panel/test-run/preparation/document-processing/options.tsx @@ -21,7 +21,7 @@ const Options = ({ initialData, configurations, schema, CustomActions, onSubmit if (!result.success) { const issues = result.error.issues const firstIssue = issues[0] - const errorMessage = `Path: ${firstIssue.path.join('.')} Error: ${firstIssue.message}` + const errorMessage = `Path: ${firstIssue!.path.join('.')} Error: ${firstIssue!.message}` toast.error(errorMessage) return errorMessage } @@ -41,7 +41,7 @@ const Options = ({ initialData, configurations, schema, CustomActions, onSubmit form.handleSubmit() }} > -
+
{configurations.map((config, index) => { const FieldComponent = BaseField({ initialData, diff --git a/web/app/components/tools/marketplace/index.tsx b/web/app/components/tools/marketplace/index.tsx index 449f0b8b4a..e7d6a22d42 100644 --- a/web/app/components/tools/marketplace/index.tsx +++ b/web/app/components/tools/marketplace/index.tsx @@ -37,50 +37,50 @@ const Marketplace = ({ return ( <> -
+
{isMarketplaceArrowVisible && ( )} -
-
+
+
{t('marketplace.moreFrom', { ns: 'plugin' })}
-
+
{t('marketplace.discover', { ns: 'plugin' })} - + {t('category.models', { ns: 'plugin' })} , - + {t('category.tools', { ns: 'plugin' })} , - + {t('category.datasources', { ns: 'plugin' })} , - + {t('category.triggers', { ns: 'plugin' })} , - + {t('category.agents', { ns: 'plugin' })} , - + {t('category.extensions', { ns: 'plugin' })} {t('marketplace.and', { ns: 'plugin' })} - + {t('category.bundles', { ns: 'plugin' })} {t('operation.in', { ns: 'common' })} {t('marketplace.difyMarketplace', { ns: 'plugin' })} @@ -92,7 +92,7 @@ const Marketplace = ({
{ isLoading && page === 1 && ( -
+
) diff --git a/web/app/components/tools/mcp/detail/__tests__/content.spec.tsx b/web/app/components/tools/mcp/detail/__tests__/content.spec.tsx index 584c9d211a..5216e9eede 100644 --- a/web/app/components/tools/mcp/detail/__tests__/content.spec.tsx +++ b/web/app/components/tools/mcp/detail/__tests__/content.spec.tsx @@ -199,22 +199,22 @@ describe('MCPDetailContent', () => { describe('Rendering', () => { it('should render without crashing', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('Test MCP Server')).toBeInTheDocument() + expect(screen.getByText('Test MCP Server'))!.toBeInTheDocument() }) it('should display MCP name', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('Test MCP Server')).toBeInTheDocument() + expect(screen.getByText('Test MCP Server'))!.toBeInTheDocument() }) it('should display server identifier', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('test-mcp')).toBeInTheDocument() + expect(screen.getByText('test-mcp'))!.toBeInTheDocument() }) it('should display server URL', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('https://example.com/mcp')).toBeInTheDocument() + expect(screen.getByText('https://example.com/mcp'))!.toBeInTheDocument() }) it('should render close button', () => { @@ -227,7 +227,8 @@ describe('MCPDetailContent', () => { it('should render operation dropdown', () => { render(, { wrapper: createWrapper() }) // Operation dropdown trigger should be present - expect(document.querySelector('button')).toBeInTheDocument() + // Operation dropdown trigger should be present + expect(document.querySelector('button'))!.toBeInTheDocument() }) }) @@ -238,7 +239,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.mcp.authorize')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.authorize'))!.toBeInTheDocument() }) it('should show authorized button when authorized', () => { @@ -247,7 +248,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.auth.authorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.authorized'))!.toBeInTheDocument() }) it('should show authorization required message when not authorized', () => { @@ -256,7 +257,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.mcp.authorizingRequired')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.authorizingRequired'))!.toBeInTheDocument() }) it('should show authorization tip', () => { @@ -265,7 +266,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.mcp.authorizeTip')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.authorizeTip'))!.toBeInTheDocument() }) }) @@ -276,7 +277,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.mcp.toolsEmpty')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.toolsEmpty'))!.toBeInTheDocument() }) it('should show get tools button when empty', () => { @@ -285,7 +286,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.mcp.getTools')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.getTools'))!.toBeInTheDocument() }) }) @@ -294,7 +295,7 @@ describe('MCPDetailContent', () => { render(, { wrapper: createWrapper() }) // Icon container should be present const iconContainer = document.querySelector('[class*="rounded-xl"][class*="border"]') - expect(iconContainer).toBeInTheDocument() + expect(iconContainer)!.toBeInTheDocument() }) }) @@ -305,7 +306,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('Test MCP Server')).toBeInTheDocument() + expect(screen.getByText('Test MCP Server'))!.toBeInTheDocument() }) it('should handle long MCP name', () => { @@ -315,7 +316,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText(longName)).toBeInTheDocument() + expect(screen.getByText(longName))!.toBeInTheDocument() }) }) @@ -332,8 +333,8 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tool1')).toBeInTheDocument() - expect(screen.getByText('tool2')).toBeInTheDocument() + expect(screen.getByText('tool1'))!.toBeInTheDocument() + expect(screen.getByText('tool2'))!.toBeInTheDocument() }) it('should show single tool label when only one tool', () => { @@ -345,7 +346,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.mcp.onlyTool')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.onlyTool'))!.toBeInTheDocument() }) it('should show tools count when multiple tools', () => { @@ -360,7 +361,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText(/tools.mcp.toolsNum/)).toBeInTheDocument() + expect(screen.getByText(/tools.mcp.toolsNum/))!.toBeInTheDocument() }) }) @@ -375,7 +376,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.mcp.gettingTools')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.gettingTools'))!.toBeInTheDocument() }) it('should show updating state when updating tools', () => { @@ -388,7 +389,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.mcp.updateTools')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.updateTools'))!.toBeInTheDocument() }) it('should show authorizing button when authorizing', () => { @@ -464,7 +465,7 @@ describe('MCPDetailContent', () => { ) const authorizeBtn = screen.getByText('tools.mcp.authorize') - expect(authorizeBtn.closest('button')).toBeDisabled() + expect(authorizeBtn.closest('button'))!.toBeDisabled() }) }) @@ -483,7 +484,7 @@ describe('MCPDetailContent', () => { fireEvent.click(updateBtn) await waitFor(() => { - expect(screen.getByText('tools.mcp.toolUpdateConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.toolUpdateConfirmTitle'))!.toBeInTheDocument() }) }) @@ -503,7 +504,7 @@ describe('MCPDetailContent', () => { fireEvent.click(updateBtn) await waitFor(() => { - expect(screen.getByText('tools.mcp.toolUpdateConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.toolUpdateConfirmTitle'))!.toBeInTheDocument() }) // Confirm the update @@ -541,7 +542,7 @@ describe('MCPDetailContent', () => { fireEvent.click(editBtn) await waitFor(() => { - expect(screen.getByTestId('mcp-update-modal')).toBeInTheDocument() + expect(screen.getByTestId('mcp-update-modal'))!.toBeInTheDocument() }) }) @@ -553,7 +554,7 @@ describe('MCPDetailContent', () => { fireEvent.click(editBtn) await waitFor(() => { - expect(screen.getByTestId('mcp-update-modal')).toBeInTheDocument() + expect(screen.getByTestId('mcp-update-modal'))!.toBeInTheDocument() }) // Close modal @@ -574,7 +575,7 @@ describe('MCPDetailContent', () => { fireEvent.click(editBtn) await waitFor(() => { - expect(screen.getByTestId('mcp-update-modal')).toBeInTheDocument() + expect(screen.getByTestId('mcp-update-modal'))!.toBeInTheDocument() }) // Confirm form @@ -601,7 +602,7 @@ describe('MCPDetailContent', () => { fireEvent.click(editBtn) await waitFor(() => { - expect(screen.getByTestId('mcp-update-modal')).toBeInTheDocument() + expect(screen.getByTestId('mcp-update-modal'))!.toBeInTheDocument() }) // Confirm form @@ -624,7 +625,7 @@ describe('MCPDetailContent', () => { fireEvent.click(removeBtn) await waitFor(() => { - expect(screen.getByText('tools.mcp.delete')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.delete'))!.toBeInTheDocument() }) }) @@ -636,7 +637,7 @@ describe('MCPDetailContent', () => { fireEvent.click(removeBtn) await waitFor(() => { - expect(screen.getByText('tools.mcp.delete')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.delete'))!.toBeInTheDocument() }) // Cancel @@ -656,7 +657,7 @@ describe('MCPDetailContent', () => { fireEvent.click(removeBtn) await waitFor(() => { - expect(screen.getByText('tools.mcp.delete')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.delete'))!.toBeInTheDocument() }) // Confirm delete @@ -678,7 +679,7 @@ describe('MCPDetailContent', () => { fireEvent.click(removeBtn) await waitFor(() => { - expect(screen.getByText('tools.mcp.delete')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.delete'))!.toBeInTheDocument() }) // Confirm delete @@ -743,7 +744,7 @@ describe('MCPDetailContent', () => { }) // Get the callback function and call it - const oauthCallback = mockOpenOAuthPopup.mock.calls[0][1] + const oauthCallback = mockOpenOAuthPopup.mock.calls[0]![1] oauthCallback() await waitFor(() => { @@ -765,7 +766,7 @@ describe('MCPDetailContent', () => { // Button should be disabled const authorizeBtn = screen.getByText('tools.mcp.authorize') - expect(authorizeBtn.closest('button')).toBeDisabled() + expect(authorizeBtn.closest('button'))!.toBeDisabled() }) }) @@ -776,7 +777,7 @@ describe('MCPDetailContent', () => { , { wrapper: createWrapper() }, ) - expect(screen.getByText('tools.auth.authorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.authorized'))!.toBeInTheDocument() }) it('should call handleAuthorize when authorized button is clicked', async () => { @@ -805,7 +806,7 @@ describe('MCPDetailContent', () => { ) const authorizedBtn = screen.getByText('tools.auth.authorized') - expect(authorizedBtn.closest('button')).toBeDisabled() + expect(authorizedBtn.closest('button'))!.toBeDisabled() }) }) @@ -825,7 +826,7 @@ describe('MCPDetailContent', () => { fireEvent.click(updateBtn) await waitFor(() => { - expect(screen.getByText('tools.mcp.toolUpdateConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.toolUpdateConfirmTitle'))!.toBeInTheDocument() }) // Cancel the update diff --git a/web/app/components/tools/provider/__tests__/detail.spec.tsx b/web/app/components/tools/provider/__tests__/detail.spec.tsx index 1e3327232c..870e8adab1 100644 --- a/web/app/components/tools/provider/__tests__/detail.spec.tsx +++ b/web/app/components/tools/provider/__tests__/detail.spec.tsx @@ -187,9 +187,9 @@ describe('ProviderDetail', () => { onRefreshData={mockOnRefreshData} />, ) - expect(screen.getByTestId('title')).toHaveTextContent('Test Collection') - expect(screen.getByTestId('org-info')).toHaveTextContent('Test Author') - expect(screen.getByTestId('description')).toHaveTextContent('A test collection') + expect(screen.getByTestId('title'))!.toHaveTextContent('Test Collection') + expect(screen.getByTestId('org-info'))!.toHaveTextContent('Test Author') + expect(screen.getByTestId('description'))!.toHaveTextContent('A test collection') }) it('shows loading state initially', () => { @@ -200,7 +200,7 @@ describe('ProviderDetail', () => { onRefreshData={mockOnRefreshData} />, ) - expect(screen.getByRole('status')).toBeInTheDocument() + expect(screen.getByRole('status'))!.toBeInTheDocument() }) it('renders tool list after loading for builtIn type', async () => { @@ -212,8 +212,8 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByTestId('tool-tool-1')).toBeInTheDocument() - expect(screen.getByTestId('tool-tool-2')).toBeInTheDocument() + expect(screen.getByTestId('tool-tool-1'))!.toBeInTheDocument() + expect(screen.getByTestId('tool-tool-2'))!.toBeInTheDocument() }) }) @@ -239,7 +239,7 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.unauthorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.unauthorized'))!.toBeInTheDocument() }) }) @@ -252,7 +252,7 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.authorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.authorized'))!.toBeInTheDocument() }) }) }) @@ -273,7 +273,7 @@ describe('ProviderDetail', () => { expect(mockFetchCustomCollection).toHaveBeenCalledWith('test-collection') }) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) }) }) @@ -291,8 +291,8 @@ describe('ProviderDetail', () => { expect(mockFetchWorkflowToolDetail).toHaveBeenCalledWith('test-id') }) await waitFor(() => { - expect(screen.getByText('tools.openInStudio')).toBeInTheDocument() - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.openInStudio'))!.toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) }) }) @@ -315,7 +315,7 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.unauthorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.unauthorized'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.auth.unauthorized')) expect(mockSetShowModelModal).toHaveBeenCalled() @@ -332,7 +332,7 @@ describe('ProviderDetail', () => { />, ) const buttons = screen.getAllByRole('button') - fireEvent.click(buttons[0]) + fireEvent.click(buttons[0]!) expect(mockOnHide).toHaveBeenCalled() }) }) @@ -388,10 +388,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.unauthorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.unauthorized'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.auth.unauthorized')) - expect(screen.getByTestId('config-credential')).toBeInTheDocument() + expect(screen.getByTestId('config-credential'))!.toBeInTheDocument() }) it('saves credentials and refreshes data', async () => { @@ -403,7 +403,7 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.unauthorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.unauthorized'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.auth.unauthorized')) await act(async () => { @@ -424,7 +424,7 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.unauthorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.unauthorized'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.auth.unauthorized')) await act(async () => { @@ -445,10 +445,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.authorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.authorized'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.auth.authorized')) - expect(screen.getByTestId('config-credential')).toBeInTheDocument() + expect(screen.getByTestId('config-credential'))!.toBeInTheDocument() }) }) @@ -467,10 +467,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.unauthorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.unauthorized'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.auth.unauthorized')) - const call = mockSetShowModelModal.mock.calls[0][0] + const call = mockSetShowModelModal.mock.calls[0]![0] act(() => { call.onSaveCallback() }) @@ -497,7 +497,7 @@ describe('ProviderDetail', () => { expect(mockFetchCustomCollection).toHaveBeenCalled() }) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) }) @@ -513,10 +513,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.createTool.editAction')) - expect(screen.getByTestId('edit-custom-modal')).toBeInTheDocument() + expect(screen.getByTestId('edit-custom-modal'))!.toBeInTheDocument() await act(async () => { fireEvent.click(screen.getByTestId('edit-save')) }) @@ -538,11 +538,11 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.createTool.editAction')) fireEvent.click(screen.getByTestId('edit-remove')) - expect(screen.getByText('tools.createTool.deleteToolConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.deleteToolConfirmTitle'))!.toBeInTheDocument() await act(async () => { fireEvent.click(getDeleteConfirmButton()) }) @@ -574,10 +574,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('query')).toBeInTheDocument() - expect(screen.getByText('string')).toBeInTheDocument() - expect(screen.getByText('Search query')).toBeInTheDocument() - expect(screen.getByText('limit')).toBeInTheDocument() + expect(screen.getByText('query'))!.toBeInTheDocument() + expect(screen.getByText('string'))!.toBeInTheDocument() + expect(screen.getByText('Search query'))!.toBeInTheDocument() + expect(screen.getByText('limit'))!.toBeInTheDocument() }) }) @@ -590,10 +590,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.createTool.editAction')) - expect(screen.getByTestId('workflow-tool-modal')).toBeInTheDocument() + expect(screen.getByTestId('workflow-tool-modal'))!.toBeInTheDocument() await act(async () => { fireEvent.click(screen.getByTestId('wf-save')) }) @@ -612,11 +612,11 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.createTool.editAction')) fireEvent.click(screen.getByTestId('wf-remove')) - expect(screen.getByText('tools.createTool.deleteToolConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.deleteToolConfirmTitle'))!.toBeInTheDocument() await act(async () => { fireEvent.click(getDeleteConfirmButton()) }) @@ -637,10 +637,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.auth.unauthorized')).toBeInTheDocument() + expect(screen.getByText('tools.auth.unauthorized'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.auth.unauthorized')) - expect(screen.getByTestId('config-credential')).toBeInTheDocument() + expect(screen.getByTestId('config-credential'))!.toBeInTheDocument() fireEvent.click(screen.getByTestId('credential-cancel')) expect(screen.queryByTestId('config-credential')).not.toBeInTheDocument() }) @@ -657,10 +657,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.createTool.editAction')) - expect(screen.getByTestId('edit-custom-modal')).toBeInTheDocument() + expect(screen.getByTestId('edit-custom-modal'))!.toBeInTheDocument() fireEvent.click(screen.getByTestId('edit-close')) expect(screen.queryByTestId('edit-custom-modal')).not.toBeInTheDocument() }) @@ -674,10 +674,10 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.createTool.editAction')) - expect(screen.getByTestId('workflow-tool-modal')).toBeInTheDocument() + expect(screen.getByTestId('workflow-tool-modal'))!.toBeInTheDocument() fireEvent.click(screen.getByTestId('wf-close')) expect(screen.queryByTestId('workflow-tool-modal')).not.toBeInTheDocument() }) @@ -696,11 +696,11 @@ describe('ProviderDetail', () => { />, ) await waitFor(() => { - expect(screen.getByText('tools.createTool.editAction')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.editAction'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('tools.createTool.editAction')) fireEvent.click(screen.getByTestId('edit-remove')) - expect(screen.getByText('tools.createTool.deleteToolConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.deleteToolConfirmTitle'))!.toBeInTheDocument() fireEvent.click(getDeleteCancelButton()) await waitFor(() => { expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument() diff --git a/web/app/components/tools/workflow-tool/__tests__/method-selector.spec.tsx b/web/app/components/tools/workflow-tool/__tests__/method-selector.spec.tsx index 4379bec035..d1126bf762 100644 --- a/web/app/components/tools/workflow-tool/__tests__/method-selector.spec.tsx +++ b/web/app/components/tools/workflow-tool/__tests__/method-selector.spec.tsx @@ -26,26 +26,28 @@ describe('MethodSelector', () => { renderComponent() // Should display the current method text - expect(screen.getByText('tools.createTool.toolInput.methodParameter')).toBeInTheDocument() + // Should display the current method text + expect(screen.getByText('tools.createTool.toolInput.methodParameter'))!.toBeInTheDocument() }) it('should render with llm value selected', () => { renderComponent({ value: 'llm' }) - expect(screen.getByText('tools.createTool.toolInput.methodParameter')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodParameter'))!.toBeInTheDocument() }) it('should render with form value selected', () => { renderComponent({ value: 'form' }) - expect(screen.getByText('tools.createTool.toolInput.methodSetting')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodSetting'))!.toBeInTheDocument() }) it('should render with undefined value', () => { renderComponent({ value: undefined }) // When value is undefined, it should show the form method text (else branch) - expect(screen.getByText('tools.createTool.toolInput.methodSetting')).toBeInTheDocument() + // When value is undefined, it should show the form method text (else branch) + expect(screen.getByText('tools.createTool.toolInput.methodSetting'))!.toBeInTheDocument() }) it('should render arrow down icon', () => { @@ -53,7 +55,7 @@ describe('MethodSelector', () => { // The arrow icon is rendered with remixicon const arrowIcon = document.querySelector('.remixicon') - expect(arrowIcon).toBeInTheDocument() + expect(arrowIcon)!.toBeInTheDocument() }) }) @@ -62,19 +64,19 @@ describe('MethodSelector', () => { it('should display methodParameter when value is llm', () => { renderComponent({ value: 'llm' }) - expect(screen.getByText('tools.createTool.toolInput.methodParameter')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodParameter'))!.toBeInTheDocument() }) it('should display methodSetting when value is form', () => { renderComponent({ value: 'form' }) - expect(screen.getByText('tools.createTool.toolInput.methodSetting')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodSetting'))!.toBeInTheDocument() }) it('should handle empty string value as non-llm', () => { renderComponent({ value: '' }) - expect(screen.getByText('tools.createTool.toolInput.methodSetting')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodSetting'))!.toBeInTheDocument() }) }) @@ -90,8 +92,8 @@ describe('MethodSelector', () => { // Dropdown should now show both options with tips await waitFor(() => { - expect(screen.getByText('tools.createTool.toolInput.methodParameterTip')).toBeInTheDocument() - expect(screen.getByText('tools.createTool.toolInput.methodSettingTip')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodParameterTip'))!.toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodSettingTip'))!.toBeInTheDocument() }) }) @@ -106,12 +108,12 @@ describe('MethodSelector', () => { // Wait for dropdown to open await waitFor(() => { - expect(screen.getByText('tools.createTool.toolInput.methodParameterTip')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodParameterTip'))!.toBeInTheDocument() }) // Click the llm option (by finding the method parameter option in dropdown) const llmOption = screen.getAllByText('tools.createTool.toolInput.methodParameter')[0] - await user.click(llmOption) + await user.click(llmOption!) expect(onChange).toHaveBeenCalledWith('llm') }) @@ -127,12 +129,12 @@ describe('MethodSelector', () => { // Wait for dropdown to open await waitFor(() => { - expect(screen.getByText('tools.createTool.toolInput.methodSettingTip')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodSettingTip'))!.toBeInTheDocument() }) // Click the form option (by finding the method setting option in dropdown) const formOption = screen.getAllByText('tools.createTool.toolInput.methodSetting')[0] - await user.click(formOption) + await user.click(formOption!) expect(onChange).toHaveBeenCalledWith('form') }) @@ -146,7 +148,7 @@ describe('MethodSelector', () => { // First click - open await user.click(trigger) await waitFor(() => { - expect(screen.getByText('tools.createTool.toolInput.methodParameterTip')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodParameterTip'))!.toBeInTheDocument() }) // Second click - close @@ -163,7 +165,7 @@ describe('MethodSelector', () => { renderComponent() const trigger = document.querySelector('.hover\\:bg-background-section-burn') - expect(trigger).toBeInTheDocument() + expect(trigger)!.toBeInTheDocument() }) it('should apply open state styles when dropdown is open', async () => { @@ -175,7 +177,7 @@ describe('MethodSelector', () => { await waitFor(() => { const openTrigger = document.querySelector('.bg-background-section-burn\\!') - expect(openTrigger).toBeInTheDocument() + expect(openTrigger)!.toBeInTheDocument() }) }) @@ -189,7 +191,7 @@ describe('MethodSelector', () => { await waitFor(() => { // Check icon should be visible for llm option const checkIcon = document.querySelector('.text-text-accent') - expect(checkIcon).toBeInTheDocument() + expect(checkIcon)!.toBeInTheDocument() }) }) @@ -203,7 +205,7 @@ describe('MethodSelector', () => { await waitFor(() => { // Check icon should be visible for form option const checkIcon = document.querySelector('.text-text-accent') - expect(checkIcon).toBeInTheDocument() + expect(checkIcon)!.toBeInTheDocument() }) }) }) @@ -219,8 +221,9 @@ describe('MethodSelector', () => { await waitFor(() => { // Should show both option titles and descriptions - expect(screen.getByText('tools.createTool.toolInput.methodParameterTip')).toBeInTheDocument() - expect(screen.getByText('tools.createTool.toolInput.methodSettingTip')).toBeInTheDocument() + // Should show both option titles and descriptions + expect(screen.getByText('tools.createTool.toolInput.methodParameterTip'))!.toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodSettingTip'))!.toBeInTheDocument() }) }) @@ -233,9 +236,9 @@ describe('MethodSelector', () => { await waitFor(() => { const dropdown = document.querySelector('.w-\\[320px\\]') - expect(dropdown).toBeInTheDocument() - expect(dropdown).toHaveClass('rounded-lg') - expect(dropdown).toHaveClass('shadow-lg') + expect(dropdown)!.toBeInTheDocument() + expect(dropdown)!.toHaveClass('rounded-lg') + expect(dropdown)!.toHaveClass('shadow-lg') }) }) @@ -267,7 +270,8 @@ describe('MethodSelector', () => { await user.click(trigger) // Should not crash and should be in a consistent state - expect(trigger).toBeInTheDocument() + // Should not crash and should be in a consistent state + expect(trigger)!.toBeInTheDocument() }) it('should handle selecting the already selected value', async () => { @@ -279,12 +283,12 @@ describe('MethodSelector', () => { await user.click(trigger) await waitFor(() => { - expect(screen.getByText('tools.createTool.toolInput.methodParameterTip')).toBeInTheDocument() + expect(screen.getByText('tools.createTool.toolInput.methodParameterTip'))!.toBeInTheDocument() }) // Click the llm option in the dropdown (the one with the tip text nearby) const llmOptionContainer = screen.getByText('tools.createTool.toolInput.methodParameterTip').closest('.cursor-pointer') - expect(llmOptionContainer).toBeInTheDocument() + expect(llmOptionContainer)!.toBeInTheDocument() await user.click(llmOptionContainer!) // Should call onChange @@ -298,7 +302,7 @@ describe('MethodSelector', () => { renderComponent() const trigger = document.querySelector('.cursor-pointer') - expect(trigger).toBeInTheDocument() + expect(trigger)!.toBeInTheDocument() }) it('should have clickable dropdown options', async () => { diff --git a/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx index 8b1ce699e7..07eb8ffe3c 100644 --- a/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx +++ b/web/app/components/workflow-app/components/workflow-onboarding-modal/start-node-option.tsx @@ -30,7 +30,7 @@ const StartNodeOption: FC = ({

{title} {subtitle && ( - + {' '} {subtitle} @@ -39,7 +39,7 @@ const StartNodeOption: FC = ({

-

+

{description}

diff --git a/web/app/components/workflow/header/__tests__/test-run-menu.spec.tsx b/web/app/components/workflow/header/__tests__/test-run-menu.spec.tsx index 40387d1e0e..7e4cbe87c6 100644 --- a/web/app/components/workflow/header/__tests__/test-run-menu.spec.tsx +++ b/web/app/components/workflow/header/__tests__/test-run-menu.spec.tsx @@ -41,7 +41,7 @@ vi.mock('@/app/components/base/ui/dropdown-menu', async () => { return open ?
{children}
: null }, DropdownMenuGroup: ({ children }: { children: React.ReactNode }) =>
{children}
, - DropdownMenuGroupLabel: ({ children, className }: { children: React.ReactNode, className?: string }) =>
{children}
, + DropdownMenuLabel: ({ children, className }: { children: React.ReactNode, className?: string }) =>
{children}
, DropdownMenuSeparator: ({ className }: { className?: string }) =>
, DropdownMenuItem: ({ children, onClick, className }: { children: React.ReactNode, onClick?: React.MouseEventHandler, className?: string }) => { const { setOpen } = useDropdownMenuContext() diff --git a/web/app/components/workflow/header/running-title.tsx b/web/app/components/workflow/header/running-title.tsx index 590c6cd329..e3e2ebab75 100644 --- a/web/app/components/workflow/header/running-title.tsx +++ b/web/app/components/workflow/header/running-title.tsx @@ -15,7 +15,7 @@ const RunningTitle = () => { {isChatMode ? `Test Chat${formatWorkflowRunIdentifier(historyWorkflowData?.finished_at)}` : `Test Run${formatWorkflowRunIdentifier(historyWorkflowData?.finished_at)}`} · - + {t('common.viewOnly', { ns: 'workflow' })}
diff --git a/web/app/components/workflow/header/test-run-menu.tsx b/web/app/components/workflow/header/test-run-menu.tsx index 5b86c3c3f5..6540875e6b 100644 --- a/web/app/components/workflow/header/test-run-menu.tsx +++ b/web/app/components/workflow/header/test-run-menu.tsx @@ -1,7 +1,7 @@ import type { ShortcutMapping } from './test-run-menu-helpers' import { forwardRef, useCallback, useImperativeHandle, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuGroupLabel, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu' +import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuLabel, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu' import { OptionRow, SingleOptionTrigger, useShortcutMenu } from './test-run-menu-helpers' export enum TriggerType { @@ -155,9 +155,9 @@ const TestRunMenu = forwardRef(({ popupClassName="w-[284px] p-1" > - + {t('common.chooseStartNodeToRun', { ns: 'workflow' })} - +
{hasUserInput && renderOption(options.userInput!)} diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 99536653ce..abafcfd8eb 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -187,18 +187,18 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { let moreDataForCheckValid let usedVars: ValueSelector[] = [] - if (node.data.type === BlockEnum.Tool) - moreDataForCheckValid = getToolCheckParams(node.data as ToolNodeType, buildInTools || [], customTools || [], workflowTools || [], language) + if (node!.data.type === BlockEnum.Tool) + moreDataForCheckValid = getToolCheckParams(node!.data as ToolNodeType, buildInTools || [], customTools || [], workflowTools || [], language) - if (node.data.type === BlockEnum.DataSource) - moreDataForCheckValid = getDataSourceCheckParams(node.data as DataSourceNodeType, dataSourceList || [], language) + if (node!.data.type === BlockEnum.DataSource) + moreDataForCheckValid = getDataSourceCheckParams(node!.data as DataSourceNodeType, dataSourceList || [], language) - if (node.data.type === BlockEnum.TriggerPlugin) - moreDataForCheckValid = getTriggerCheckParams(node.data as PluginTriggerNodeType, triggerPlugins, language) + if (node!.data.type === BlockEnum.TriggerPlugin) + moreDataForCheckValid = getTriggerCheckParams(node!.data as PluginTriggerNodeType, triggerPlugins, language) - const toolIcon = getToolIcon(node.data) - if (node.data.type === BlockEnum.Agent) { - const data = node.data as AgentNodeType + const toolIcon = getToolIcon(node!.data) + if (node!.data.type === BlockEnum.Agent) { + const data = node!.data as AgentNodeType const isReadyForCheckValid = !!strategyProviders const provider = strategyProviders?.find(provider => provider.declaration.identity.name === data.agent_strategy_provider_name) const strategy = provider?.declaration.strategies?.find(s => s.identity.name === data.agent_strategy_name) @@ -210,13 +210,13 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { } } else { - usedVars = getNodeUsedVars(node).filter(v => v.length > 0) + usedVars = getNodeUsedVars(node!).filter(v => v.length > 0) } - if (node.type === CUSTOM_NODE) { - const checkData = getCheckData(node.data) - const validator = nodesExtraData?.[node.data.type as BlockEnum]?.checkValid - const isPluginMissing = isNodePluginMissing(node.data, { builtInTools: buildInTools, customTools, workflowTools, mcpTools, triggerPlugins, dataSourceList }) + if (node!.type === CUSTOM_NODE) { + const checkData = getCheckData(node!.data) + const validator = nodesExtraData?.[node!.data.type as BlockEnum]?.checkValid + const isPluginMissing = isNodePluginMissing(node!.data, { builtInTools: buildInTools, customTools, workflowTools, mcpTools, triggerPlugins, dataSourceList }) const errorMessages: string[] = [] @@ -224,8 +224,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { errorMessages.push(t('nodes.common.pluginNotInstalled', { ns: 'workflow' })) } else { - if (node.data.type === BlockEnum.LLM) { - const modelProvider = (node.data as CommonNodeType<{ model?: ModelConfig }>).model?.provider + if (node!.data.type === BlockEnum.LLM) { + const modelProvider = (node!.data as CommonNodeType<{ model?: ModelConfig }>).model?.provider const modelIssue = getLLMModelIssue({ modelProvider, isModelProviderInstalled: isLLMModelProviderInstalled(modelProvider, installedPluginIds), @@ -240,12 +240,12 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { errorMessages.push(validationError) } - const availableVars = map[node.id].availableVars + const availableVars = map[node!.id]!.availableVars let hasInvalidVar = false for (const variable of usedVars) { if (hasInvalidVar) break - if (isSpecialVar(variable[0])) + if (isSpecialVar(variable[0]!)) continue const usedNode = availableVars.find(v => v.nodeId === variable?.[0]) if (!usedNode || !usedNode.vars.some(v => v.variable === variable?.[1])) @@ -255,17 +255,17 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { errorMessages.push(t('errorMsg.invalidVariable', { ns: 'workflow' })) } - const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false + const isStartNodeMeta = nodesExtraData?.[node!.data.type as BlockEnum]?.metaData.isStart ?? false const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true - const isUnconnected = !validNodes.some(n => n.id === node.id) + const isUnconnected = !validNodes.some(n => n.id === node!.id) const shouldShowError = errorMessages.length > 0 || (isUnconnected && !canSkipConnectionCheck) if (shouldShowError) { list.push({ - id: node.id, - type: node.data.type, - title: node.data.title, + id: node!.id, + type: node!.data.type, + title: node!.data.title, toolIcon, unConnected: isUnconnected && !canSkipConnectionCheck, errorMessages, @@ -273,7 +273,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { disableGoTo: isPluginMissing, isPluginMissing, pluginUniqueIdentifier: isPluginMissing - ? (node.data as { plugin_unique_identifier?: string }).plugin_unique_identifier + ? (node!.data as { plugin_unique_identifier?: string }).plugin_unique_identifier : undefined, }) } @@ -458,14 +458,14 @@ export const useChecklistBeforePublish = () => { const node = filteredNodes[i] let moreDataForCheckValid let usedVars: ValueSelector[] = [] - if (node.data.type === BlockEnum.Tool) - moreDataForCheckValid = getToolCheckParams(node.data as ToolNodeType, buildInTools || [], customTools || [], workflowTools || [], language) + if (node!.data.type === BlockEnum.Tool) + moreDataForCheckValid = getToolCheckParams(node!.data as ToolNodeType, buildInTools || [], customTools || [], workflowTools || [], language) - if (node.data.type === BlockEnum.DataSource) - moreDataForCheckValid = getDataSourceCheckParams(node.data as DataSourceNodeType, dataSourceList || [], language) + if (node!.data.type === BlockEnum.DataSource) + moreDataForCheckValid = getDataSourceCheckParams(node!.data as DataSourceNodeType, dataSourceList || [], language) - if (node.data.type === BlockEnum.Agent) { - const data = node.data as AgentNodeType + if (node!.data.type === BlockEnum.Agent) { + const data = node!.data as AgentNodeType const isReadyForCheckValid = !!strategyProviders const provider = strategyProviders?.find(provider => provider.declaration.identity.name === data.agent_strategy_provider_name) const strategy = provider?.declaration.strategies?.find(s => s.identity.name === data.agent_strategy_name) @@ -477,55 +477,55 @@ export const useChecklistBeforePublish = () => { } } else { - usedVars = getNodeUsedVars(node).filter(v => v.length > 0) + usedVars = getNodeUsedVars(node!).filter(v => v.length > 0) } - if (node.data.type === BlockEnum.LLM) { - const modelProvider = (node.data as CommonNodeType<{ model?: ModelConfig }>).model?.provider + if (node!.data.type === BlockEnum.LLM) { + const modelProvider = (node!.data as CommonNodeType<{ model?: ModelConfig }>).model?.provider const modelIssue = getLLMModelIssue({ modelProvider, isModelProviderInstalled: isLLMModelProviderInstalled(modelProvider, installedPluginIds), }) if (modelIssue === LLMModelIssueCode.providerPluginUnavailable) { - toast.error(`[${node.data.title}] ${t('errorMsg.configureModel', { ns: 'workflow' })}`) + toast.error(`[${node!.data.title}] ${t('errorMsg.configureModel', { ns: 'workflow' })}`) return false } } - const checkData = getCheckData(node.data, datasets, embeddingProviderModelMap) - const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid) + const checkData = getCheckData(node!.data, datasets, embeddingProviderModelMap) + const { errorMessage } = nodesExtraData![node!.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid) if (errorMessage) { - toast.error(`[${node.data.title}] ${errorMessage}`) + toast.error(`[${node!.data.title}] ${errorMessage}`) return false } - const availableVars = map[node.id].availableVars + const availableVars = map[node!.id]!.availableVars for (const variable of usedVars) { - const isSpecialVars = isSpecialVar(variable[0]) + const isSpecialVars = isSpecialVar(variable[0]!) if (!isSpecialVars) { const usedNode = availableVars.find(v => v.nodeId === variable?.[0]) if (usedNode) { const usedVar = usedNode.vars.find(v => v.variable === variable?.[1]) if (!usedVar) { - toast.error(`[${node.data.title}] ${t('errorMsg.invalidVariable', { ns: 'workflow' })}`) + toast.error(`[${node!.data.title}] ${t('errorMsg.invalidVariable', { ns: 'workflow' })}`) return false } } else { - toast.error(`[${node.data.title}] ${t('errorMsg.invalidVariable', { ns: 'workflow' })}`) + toast.error(`[${node!.data.title}] ${t('errorMsg.invalidVariable', { ns: 'workflow' })}`) return false } } } - const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false + const isStartNodeMeta = nodesExtraData?.[node!.data.type as BlockEnum]?.metaData.isStart ?? false const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true - const isUnconnected = !validNodes.some(n => n.id === node.id) + const isUnconnected = !validNodes.some(n => n.id === node!.id) if (isUnconnected && !canSkipConnectionCheck) { - toast.error(`[${node.data.title}] ${t('common.needConnectTip', { ns: 'workflow' })}`) + toast.error(`[${node!.data.title}] ${t('common.needConnectTip', { ns: 'workflow' })}`) return false } } diff --git a/web/app/components/workflow/nodes/_base/components/editor/text-editor.tsx b/web/app/components/workflow/nodes/_base/components/editor/text-editor.tsx index 1cd18f2907..3de713c067 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/text-editor.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/text-editor.tsx @@ -53,7 +53,7 @@ const TextEditor: FC = ({ onChange={e => onChange(e.target.value)} onFocus={setIsFocus} onBlur={handleBlur} - className="h-full w-full resize-none border-none bg-transparent px-3 text-[13px] font-normal leading-[18px] text-gray-900 placeholder:text-gray-300 focus:outline-hidden" + className="h-full w-full resize-none border-none bg-transparent px-3 text-[13px] leading-[18px] font-normal text-gray-900 placeholder:text-gray-300 focus:outline-hidden" placeholder={placeholder} readOnly={readonly} /> diff --git a/web/app/components/workflow/nodes/_base/components/mixed-variable-text-input/placeholder.tsx b/web/app/components/workflow/nodes/_base/components/mixed-variable-text-input/placeholder.tsx index 9884e85657..4af622bb46 100644 --- a/web/app/components/workflow/nodes/_base/components/mixed-variable-text-input/placeholder.tsx +++ b/web/app/components/workflow/nodes/_base/components/mixed-variable-text-input/placeholder.tsx @@ -27,9 +27,9 @@ const Placeholder = () => { >
{t('nodes.tool.insertPlaceholder1', { ns: 'workflow' })} -
/
+
/
{ e.preventDefault() e.stopPropagation() diff --git a/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx b/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx index bc979eed60..c0a4f0b537 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx @@ -2,18 +2,17 @@ import type { CommonNodeType, OnSelectBlock, } from '@/app/components/workflow/types' -import { RiMoreFill } from '@remixicon/react' import { intersection } from 'es-toolkit/array' import { useCallback, } from 'react' import { useTranslation } from 'react-i18next' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' import { Button } from '@/app/components/base/ui/button' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' import BlockSelector from '@/app/components/workflow/block-selector' import { useAvailableBlocks, @@ -86,18 +85,21 @@ const Operator = ({ } = useNodesInteractions() return ( - - onOpenChange(!open)}> + }> - - + +
handleNodeDisconnect(nodeId)} + onClick={() => { + onOpenChange(false) + handleNodeDisconnect(nodeId) + }} > {t('common.disconnect', { ns: 'workflow' })}
@@ -115,14 +120,17 @@ const Operator = ({
handleNodeDelete(nodeId)} + onClick={() => { + onOpenChange(false) + handleNodeDelete(nodeId) + }} > {t('operation.delete', { ns: 'common' })}
- - + + ) } diff --git a/web/app/components/workflow/nodes/_base/components/node-control.tsx b/web/app/components/workflow/nodes/_base/components/node-control.tsx index aab4b5065d..547d0b1daa 100644 --- a/web/app/components/workflow/nodes/_base/components/node-control.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-control.tsx @@ -9,7 +9,11 @@ import { useTranslation } from 'react-i18next' import { Stop, } from '@/app/components/base/icons/src/vender/line/mediaAndDevices' -import Tooltip from '@/app/components/base/tooltip' +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/app/components/base/ui/tooltip' import { useWorkflowStore } from '@/app/components/workflow/store' import { useNodesInteractions, @@ -46,7 +50,8 @@ const NodeControl: FC = ({ `} >
e.stopPropagation()} onClick={e => e.stopPropagation()} > { @@ -71,11 +76,13 @@ const NodeControl: FC = ({ isSingleRunning ? : ( - - + + } + /> + + {t('panel.runThisStep', { ns: 'workflow' })} + ) } diff --git a/web/app/components/workflow/nodes/_base/components/panel-operator/index.tsx b/web/app/components/workflow/nodes/_base/components/panel-operator/index.tsx index 173a084dcf..7b3469aaba 100644 --- a/web/app/components/workflow/nodes/_base/components/panel-operator/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/panel-operator/index.tsx @@ -1,23 +1,24 @@ import type { OffsetOptions } from '@floating-ui/react' import type { Node } from '@/app/components/workflow/types' -import { RiMoreFill } from '@remixicon/react' +import { cn } from '@langgenius/dify-ui/cn' import { memo, useCallback, useState, } from 'react' +import { useTranslation } from 'react-i18next' import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' import PanelOperatorPopup from './panel-operator-popup' type PanelOperatorProps = { id: string data: Node['data'] triggerClassName?: string - offset?: OffsetOptions + offset?: OffsetOptions | number onOpenChange?: (open: boolean) => void inNode?: boolean showHelpLink?: boolean @@ -33,7 +34,16 @@ const PanelOperator = ({ onOpenChange, showHelpLink = true, }: PanelOperatorProps) => { + const { t } = useTranslation() const [open, setOpen] = useState(false) + const sideOffset = typeof offset === 'number' + ? offset + : typeof offset === 'object' && offset && 'mainAxis' in offset && typeof offset.mainAxis === 'number' + ? offset.mainAxis + : 4 + const alignOffset = typeof offset === 'object' && offset && 'crossAxis' in offset && typeof offset.crossAxis === 'number' + ? offset.crossAxis + : 0 const handleOpenChange = useCallback((newOpen: boolean) => { setOpen(newOpen) @@ -43,33 +53,35 @@ const PanelOperator = ({ }, [onOpenChange]) return ( - - handleOpenChange(!open)}> -
- -
-
- + } + aria-label={t('operation.more', { ns: 'common' })} + className={cn( + 'nodrag nopan nowheel flex h-6 w-6 cursor-pointer items-center justify-center rounded-md hover:bg-state-base-hover', + open && 'bg-state-base-hover', + triggerClassName, + )} + > + + + setOpen(false)} + onClosePopup={() => handleOpenChange(false)} showHelpLink={showHelpLink} /> - -
+ + ) } diff --git a/web/app/components/workflow/nodes/_base/components/title-description-input.tsx b/web/app/components/workflow/nodes/_base/components/title-description-input.tsx index db34c35c91..d393896bec 100644 --- a/web/app/components/workflow/nodes/_base/components/title-description-input.tsx +++ b/web/app/components/workflow/nodes/_base/components/title-description-input.tsx @@ -53,7 +53,7 @@ export const TitleInput = memo(({ onChange={handleChange} onKeyDown={handleKeyDown} className={` - system-xl-semibold mr-2 h-7 min-w-0 grow appearance-none rounded-md border border-transparent bg-transparent px-1 text-text-primary + mr-2 h-7 min-w-0 grow appearance-none rounded-md border border-transparent bg-transparent px-1 system-xl-semibold text-text-primary outline-hidden focus:shadow-xs `} placeholder={t('common.addTitle', { ns: 'workflow' }) || ''} @@ -83,8 +83,8 @@ export const DescriptionInput = memo(({ return (
diff --git a/web/app/components/workflow/nodes/_base/components/variable/constant-field.tsx b/web/app/components/workflow/nodes/_base/components/variable/constant-field.tsx index c5ff2ba98b..41be240bda 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/constant-field.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/constant-field.tsx @@ -57,7 +57,7 @@ const ConstantField: FC = ({ {schema.type === FormTypeEnum.textNumber && ( = ({ const list = outputKeyOrders.map((key) => { return { variable: key, - variable_type: outputs[key]?.type, + variable_type: outputs[key]?.type!, } }) @@ -50,15 +50,15 @@ const OutputVarList: FC = ({ const handleVarNameChange = useCallback((index: number) => { return (e: React.ChangeEvent) => { - const oldKey = list[index].variable + const oldKey = list[index]!.variable replaceSpaceWithUnderscoreInVarNameInput(e.target) const newKey = e.target.value - validateVarInput(list.toSpliced(index, 1), newKey) + validateVarInput(list.filter((_, itemIndex) => itemIndex !== index), newKey) const newOutputs = produce(outputs, (draft) => { - draft[newKey] = draft[oldKey] + draft[newKey] = draft[oldKey]! // Only delete old key if no other entry shares this name if (!list.some((item, i) => i !== index && item.variable === oldKey)) delete draft[oldKey] @@ -69,9 +69,9 @@ const OutputVarList: FC = ({ const handleVarTypeChange = useCallback((index: number) => { return (value: string) => { - const key = list[index].variable + const key = list[index]!.variable const newOutputs = produce(outputs, (draft) => { - draft[key].type = value as VarType + draft[key]!.type = value as VarType }) onChange(newOutputs) } diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx index 3caf1327c9..b8ec03eb93 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx @@ -67,11 +67,11 @@ const VarList: FC = ({ const newKey = e.target.value - validateVarInput(list.toSpliced(index, 1), newKey) + validateVarInput(list.filter((_, itemIndex) => itemIndex !== index), newKey) - onVarNameChange?.(list[index].variable, newKey) + onVarNameChange?.(list[index]!.variable, newKey) const newList = produce(list, (draft) => { - draft[index].variable = newKey + draft[index]!.variable = newKey }) onChange(newList) } @@ -81,26 +81,26 @@ const VarList: FC = ({ return (value: ValueSelector | string, varKindType: VarKindType, varInfo?: Var) => { const newList = produce(list, (draft) => { if (!isSupportConstantValue || varKindType === VarKindType.variable) { - draft[index].value_selector = value as ValueSelector - draft[index].value_type = varInfo?.type + draft[index]!.value_selector = value as ValueSelector + draft[index]!.value_type = varInfo?.type if (isSupportConstantValue) - draft[index].variable_type = VarKindType.variable + draft[index]!.variable_type = VarKindType.variable - if (!draft[index].variable) { + if (!draft[index]!.variable) { const variables = draft.map(v => v.variable) - let newVarName = value[value.length - 1] + let newVarName = value[value.length - 1]! let count = 1 - while (variables.includes(newVarName)) { + while (variables.includes(newVarName!)) { newVarName = `${value[value.length - 1]}_${count}` count++ } - draft[index].variable = newVarName + draft[index]!.variable = newVarName } } else { - draft[index].variable_type = VarKindType.constant - draft[index].value_selector = value as ValueSelector - draft[index].value = value as string + draft[index]!.variable_type = VarKindType.constant + draft[index]!.value_selector = value as ValueSelector + draft[index]!.value = value as string } }) onChange(newList) diff --git a/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts b/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts index a77af2daef..188084345e 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts @@ -51,7 +51,7 @@ function useOutputVarList({ } = useDebounceFn( (id: string, newName: string) => { const oldName = oldNameRecord.current[id] - renameInspectVarName(id, oldName, newName) + renameInspectVarName(id, oldName!, newName) delete oldNameRecord.current[id] }, { wait: 500 }, @@ -73,9 +73,9 @@ function useOutputVarList({ } if (newKey) { - handleOutVarRenameChange(id, [id, outputKeyOrders[changedIndex!]], [id, newKey]) + handleOutVarRenameChange(id, [id, outputKeyOrders[changedIndex!]!], [id, newKey]) if (!(id in oldNameRecord.current)) - oldNameRecord.current[id] = outputKeyOrders[changedIndex!] + oldNameRecord.current[id] = outputKeyOrders[changedIndex!]! renameInspectNameWithDebounce(id, newKey) } else if (changedIndex === undefined) { @@ -126,7 +126,7 @@ function useOutputVarList({ hideRemoveVarConfirm() }, [deleteInspectVar, hideRemoveVarConfirm, id, nodesWithInspectVars, removeUsedVarInNodes, removedVar]) const handleRemoveVariable = useCallback((index: number) => { - const key = outputKeyOrders[index] + const key = outputKeyOrders[index]! if (isVarUsedInNodes([id, key])) { showRemoveVarConfirm() @@ -137,15 +137,15 @@ function useOutputVarList({ const newOutputKeyOrders = outputKeyOrders.filter((_, i) => i !== index) const newInputs = produce(inputs, (draft: any) => { // Only delete from outputs when no remaining entry shares this name - if (!newOutputKeyOrders.includes(key)) - delete draft[varKey][key] + if (!newOutputKeyOrders.includes(key!)) + delete draft[varKey][key!] if ((inputs as CodeNodeType).type === BlockEnum.Code && (inputs as CodeNodeType).error_strategy === ErrorHandleTypeEnum.defaultValue && varKey === 'outputs') draft.default_value = getDefaultValue(draft as any) }) setInputs(newInputs) onOutputKeyOrdersChange(newOutputKeyOrders) - if (!newOutputKeyOrders.includes(key)) { + if (!newOutputKeyOrders.includes(key!)) { const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { return varItem.name === key })?.id diff --git a/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx b/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx index f59de9e874..1a12ab129d 100644 --- a/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx +++ b/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx @@ -47,7 +47,7 @@ vi.mock('@/app/components/base/ui/dropdown-menu', async () => { return open ?
{children}
: null }, DropdownMenuGroup: ({ children }: { children: React.ReactNode }) =>
{children}
, - DropdownMenuGroupLabel: ({ children }: { children: React.ReactNode }) =>
{children}
, + DropdownMenuLabel: ({ children }: { children: React.ReactNode }) =>
{children}
, DropdownMenuSeparator: () =>
, DropdownMenuItem: ({ children, diff --git a/web/app/components/workflow/nodes/assigner/components/operation-selector.tsx b/web/app/components/workflow/nodes/assigner/components/operation-selector.tsx index 333aa5b2cd..a22b8f5f3f 100644 --- a/web/app/components/workflow/nodes/assigner/components/operation-selector.tsx +++ b/web/app/components/workflow/nodes/assigner/components/operation-selector.tsx @@ -3,18 +3,14 @@ import type { WriteMode } from '../types' import type { Item } from '../utils' import type { VarType } from '@/app/components/workflow/types' import { cn } from '@langgenius/dify-ui/cn' -import { - RiArrowDownSLine, - RiCheckLine, -} from '@remixicon/react' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, - DropdownMenuGroupLabel, DropdownMenuItem, + DropdownMenuLabel, DropdownMenuSeparator, DropdownMenuTrigger, } from '@/app/components/base/ui/dropdown-menu' @@ -68,7 +64,7 @@ const OperationSelector: FC = ({ {selectedItem && isOperationItem(selectedItem) ? t(`nodes.assigner.operations.${selectedItem.name}`, { ns: 'workflow' }) : t('nodes.assigner.operations.title', { ns: 'workflow' })}
- + = ({ popupClassName={cn('w-[140px]', popupClassName)} > - {t('nodes.assigner.operations.title', { ns: 'workflow' })} + {t('nodes.assigner.operations.title', { ns: 'workflow' })} {items.map(item => ( !isOperationItem(item) ? ( @@ -94,7 +90,7 @@ const OperationSelector: FC = ({
{item.value === value && (
- +
)} diff --git a/web/app/components/workflow/nodes/http/node.tsx b/web/app/components/workflow/nodes/http/node.tsx index e40d6e2b6a..97c88da18d 100644 --- a/web/app/components/workflow/nodes/http/node.tsx +++ b/web/app/components/workflow/nodes/http/node.tsx @@ -15,8 +15,8 @@ const Node: FC> = ({ return (
-
{method}
-
+
{method}
+
{ const { t } = useTranslation() return ( -
+
diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx index 5fec9ad5b6..8505fe76f9 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx @@ -40,12 +40,12 @@ const ConditionNumber = ({ }, [onChange]) return ( -
+
-
+
{ valueMethod === 'variable' && !isCommonVariable && ( +
-
+
{ valueMethod === 'variable' && !isCommonVariable && ( { const { t } = useTranslation() return ( -
+
diff --git a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/item.tsx b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/item.tsx index 859be28614..cc254d324e 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/item.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/item.tsx @@ -30,15 +30,15 @@ const Item: FC = ({
{payload.name}
-
{payload.type}
+
{payload.type}
{payload.required && ( -
{t(`${i18nPrefix}.addExtractParameterContent.required`, { ns: 'workflow' })}
+
{t(`${i18nPrefix}.addExtractParameterContent.required`, { ns: 'workflow' })}
)}
-
{payload.description}
+
{payload.description}
> = ({ return (
-
+
{/* Webhook URL Section */}
@@ -138,7 +138,7 @@ const Panel: FC> = ({
{isPrivateOrLocalAddress(inputs.webhook_debug_url) && ( -
+
{t(`${i18nPrefix}.debugUrlPrivateAddressWarning`, { ns: 'workflow' })}
)} @@ -197,7 +197,7 @@ const Panel: FC> = ({
-
-