Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-04-14 10:14:13 +08:00
commit 8747e3a2d3
99 changed files with 967 additions and 407 deletions

View File

@ -6,14 +6,7 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- packages/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
- .nvmrc
concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
prompt_args: AdvancedPromptTemplateArgs = {
"app_mode": args.app_mode,
"model_mode": args.model_mode,
"model_name": args.model_name,
"has_context": args.has_context,
}
return AdvancedPromptTemplateService.get_prompt(prompt_args)

View File

@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,

View File

@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
def _build_backstage_input_url(form_token: str | None) -> str | None:
@ -214,7 +214,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
Get advanced chat app workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
# Default to DEBUGGING if not specified
triggered_from = (
@ -356,7 +360,11 @@ class WorkflowRunListApi(Resource):
Get workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (

View File

@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
node_id = args.node_id
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get all triggers for this app using select API
triggers = (
session.execute(

View File

@ -1,7 +1,10 @@
import logging
import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@ -42,12 +45,13 @@ from libs.token import (
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
@ -91,10 +95,12 @@ class LoginApi(Resource):
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
@ -110,14 +116,20 @@ class LoginApi(Resource):
invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
_log_console_login_failure(
email=normalized_email,
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
)
raise InvalidEmailError()
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != args.code:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args.token)
try:
account = _get_account_with_case_fallback(original_email)
except Unauthorized as exc:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Console login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import TypedDict
from flask import request
@ -13,6 +14,14 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationLangContent(TypedDict, total=False):
lang: str
title: str
subtitle: str
body: str
titlePicUrl: str
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict:
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
return (
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
)
class DismissNotificationPayload(BaseModel):
@ -71,7 +82,7 @@ class NotificationApi(Resource):
notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from typing import Any, Literal
import pytz
from flask import request
@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")

View File

@ -20,7 +20,7 @@ from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
from services.operation_service import OperationService, UtmInfo
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict) -> VariableEntity:
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
try:
variable_type = VariableEntityType(variable_type_raw)
except ValueError as e:
raise MCPRequestError(
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
) from e
variable = item[variable_type_raw]
return VariableEntity(
type=variable_type,
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
json_schema=variable.get("json_schema"),
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)

View File

@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = []
result: list[dict[str, Any]] = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result

View File

@ -5,6 +5,7 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
class FormDefinitionPayload(TypedDict):
form_content: Any
inputs: Any
resolved_default_values: dict[str, str]
user_actions: Any
expiration_time: int
site: NotRequired[dict]
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump()
payload = {
payload: FormDefinitionPayload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),

View File

@ -1,7 +1,10 @@
import logging
from flask import make_response, request
from flask_restx import Resource
from jwt import InvalidTokenError
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@ -20,7 +23,7 @@ from controllers.console.wraps import (
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import EmailStr
from libs.helper import EmailStr, extract_remote_ip
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@ -29,9 +32,11 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
@field_validator("password")
@ -76,14 +81,18 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
payload = LoginPayload.model_validate(web_ns.payload or {})
normalized_email = payload.email.lower()
try:
account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError()
except services.errors.account.AccountNotFoundError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != payload.code:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(payload.token)
account = WebAppAuthService.get_user_through_email(token_email)
try:
account = WebAppAuthService.get_user_through_email(token_email)
except Unauthorized as exc:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
if not account:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Web login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

View File

@ -1,5 +1,6 @@
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any
from flask import make_response, request
from flask_restx import Resource
@ -103,21 +104,23 @@ class PassportResource(Resource):
return response
def decode_enterprise_webapp_user_id(jwt_token: str | None):
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
"""
Decode the enterprise user session from the Authorization header.
"""
if not jwt_token:
return None
decoded = PassportService().verify(jwt_token)
decoded: dict[str, Any] = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
def exchange_token_for_existing_web_user(
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
"""
Exchange a token for an existing web user session.
"""

View File

@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast
from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
@ -113,12 +113,12 @@ class AppSiteInfo:
}
def serialize_site(site: Site) -> dict:
def serialize_site(site: Site) -> dict[str, Any]:
"""Serialize Site model using the same schema as AppSiteApi."""
return cast(dict, marshal(site, AppSiteApi.site_fields))
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))

View File

@ -138,7 +138,9 @@ class DatasetConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for dataset feature
@ -172,7 +174,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
"""
Extract dataset config for legacy compatibility

View File

@ -108,7 +108,7 @@ class ModelConfigManager:
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict):
def validate_model_completion_params(cls, cp: dict[str, Any]):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")

View File

@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate pre_prompt and set defaults for prompt feature
depending on the config['model']
@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict):
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
"""
Validate post_prompt and set defaults for prompt feature

View File

@ -1,5 +1,5 @@
import re
from typing import cast
from typing import Any, cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType
@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form
@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
return config, related_config_keys
@classmethod
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form
@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
return config, ["user_input_form"]
@classmethod
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_external_data_tools_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for external data fetch feature

View File

@ -30,7 +30,7 @@ class FileUploadConfigManager:
return FileUploadConfig.model_validate(file_upload_dict)
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for file upload feature

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, ValidationError
@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError:

View File

@ -1,6 +1,9 @@
from typing import Any
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
return opening_statement, suggested_questions_list
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for opening statement feature

View File

@ -1,6 +1,9 @@
from typing import Any
class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict:
@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
return show_retrieve_source
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for retriever resource feature

View File

@ -1,6 +1,9 @@
from typing import Any
class SpeechToTextConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
return speech_to_text
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for speech to text feature

View File

@ -1,6 +1,9 @@
from typing import Any
class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
return suggested_questions_after_answer
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for suggested questions feature

View File

@ -1,9 +1,11 @@
from typing import Any
from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict):
def convert(cls, config: dict[str, Any]):
"""
Convert model config to model config
@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
return text_to_speech
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for text to speech feature

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -1,3 +1,5 @@
from typing import Any
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> dict[str, Any]:
"""
Validate for pipeline config

View File

@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
next_page_parameters: dict[str, Any] | None = None,
):
"""
Get files in a folder.

View File

@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus

View File

@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None
@field_validator("parameters", mode="before")
@classmethod
@ -192,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel):
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
tool_config: dict[str, Any] | None = None
@classmethod
def empty(cls) -> DatasourceInvokeMeta:
@ -242,7 +242,7 @@ class OnlineDocumentPage(BaseModel):
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: dict | None = Field(None, description="The page icon")
page_icon: dict[str, Any] | None = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: str | None = Field(None, description="The parent page id")
@ -301,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel):
Get website crawl request
"""
crawl_parameters: dict = Field(..., description="The crawl parameters")
crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters")
class WebSiteInfoDetail(BaseModel):
@ -358,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel):
bucket: str | None = Field(None, description="The file bucket")
files: list[OnlineDriveFile] = Field(..., description="The file list")
is_truncated: bool = Field(False, description="Whether the result is truncated")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesRequest(BaseModel):
@ -369,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
bucket: str | None = Field(None, description="The file bucket")
prefix: str = Field(..., description="The parent folder ID")
max_keys: int = Field(20, description="Page size for pagination")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesResponse(BaseModel):

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: dict | None = None
data_source_info: dict[str, Any] | None = None
name: str
indexing_status: str
error: str | None = None

View File

@ -6,6 +6,7 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
"""
Get current credentials.
@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
return session.execute(stmt).scalar_one_or_none()
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
"""
Get provider credentials.
@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
"""
Validate custom credentials.
:param credentials: provider credentials
@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None):
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
def update_provider_credential(
self,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
):
@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
def _get_specific_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str
) -> dict | None:
) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> dict[str, Any] | None:
"""
Get custom model credentials.
@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
self,
model_type: ModelType,
model: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
self,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
credential_id: str,
) -> None:
"""
Update a custom model credential.
@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
"""
Get model schema
"""
@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from enum import StrEnum, auto
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, ConfigDict, Field
@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
enabled: bool
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
credentials: dict | None = None
credentials: dict[str, Any] | None = None
class CustomProviderConfiguration(BaseModel):
@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
Model class for provider custom configuration.
"""
credentials: dict
credentials: dict[str, Any]
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] = []
@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict | None
credentials: dict[str, Any] | None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = []

View File

@ -1,6 +1,7 @@
import json
from enum import StrEnum
from json import JSONDecodeError
from typing import Any
from extensions.ext_redis import redis_client
@ -15,7 +16,7 @@ class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""
Get cached model provider credentials.
@ -33,7 +34,7 @@ class ProviderCredentialsCache:
else:
return None
def set(self, credentials: dict):
def set(self, credentials: dict[str, Any]):
"""
Cache model provider credentials.

View File

@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC):
"""Generate cache key based on subclass implementation"""
pass
def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""
def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""Get cached provider credentials"""
return None

View File

@ -1,6 +1,7 @@
import json
from enum import StrEnum
from json import JSONDecodeError
from typing import Any
from extensions.ext_redis import redis_client
@ -18,7 +19,7 @@ class ToolParameterCache:
f":identity_id:{identity_id}"
)
def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""
Get cached model provider credentials.
@ -36,7 +37,7 @@ class ToolParameterCache:
else:
return None
def set(self, parameters: dict):
def set(self, parameters: dict[str, Any]):
"""Cache model provider credentials."""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

View File

@ -115,7 +115,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
@ -126,7 +126,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
@ -137,7 +137,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@ -147,7 +147,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
@ -528,7 +528,7 @@ class LBModelManager:
model_type: ModelType,
model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: dict | None = None,
managed_credentials: dict[str, Any] | None = None,
):
"""
Load balancing model manager

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""
@ -23,7 +25,7 @@ class ApiModeration(Moderation):
name: str = "api"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -41,7 +43,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -73,7 +75,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
if self.config is None:
raise ValueError("The config is not set.")
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""
@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
"""
Validate the incoming form config data.
@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@abstractmethod
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):

View File

@ -1,3 +1,5 @@
from typing import Any
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -24,7 +26,7 @@ class ModerationFactory:
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review

View File

@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:

View File

@ -1,3 +1,5 @@
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType
from core.model_manager import ModelManager
@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict):
def _is_violated(self, inputs: dict[str, Any]):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(

View File

@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
attributes: dict[str, str] = {}
@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
set_attribute(path, value)
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
def set_tool_call_attributes(
message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
) -> None:
"""Extract and assign tool call details safely."""
if not tool_call:
return

View File

@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance):
return inputs, attributes
def _parse_knowledge_retrieval_outputs(self, outputs: dict):
def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
"""Parse KR outputs and attributes from KR workflow node"""
retrieved = outputs.get("result", [])
@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance):
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _get_message_user_id(self, metadata: dict) -> str | None:
def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
if (end_user_id := metadata.get("from_end_user_id")) and (
end_user_data := db.session.get(EndUser, end_user_id)
):
@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance):
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
def _set_trace_metadata(self, span: Span, metadata: dict):
def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
token = None
try:
# NB: Set span in context such that we can use update_current_trace() API
@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance):
return messages
return prompts # Fallback to original format
def _parse_single_message(self, item: dict):
def _parse_single_message(self, item: dict[str, Any]):
"""Postprocess single message format to be standard chat message"""
role = item.get("role", "user")
msg = {"role": role, "content": item.get("text", "")}

View File

@ -3,7 +3,7 @@ import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import cast
from typing import Any, cast
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from opik import Opik, Trace
@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance):
self.add_span(span_data)
def add_trace(self, opik_trace_data: dict) -> Trace:
def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
try:
trace = self.opik_client.trace(**opik_trace_data)
logger.debug("Opik Trace created successfully")
@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance):
except Exception as e:
raise ValueError(f"Opik Failed to create trace: {str(e)}")
def add_span(self, opik_span_data: dict):
def add_span(self, opik_span_data: dict[str, Any]):
try:
self.opik_client.span(**opik_span_data)
logger.debug("Opik Span created successfully")

View File

@ -324,7 +324,7 @@ class OpsTraceManager:
@classmethod
def encrypt_tracing_config(
cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any], current_trace_config=None
):
"""
Encrypt tracing config.
@ -363,7 +363,7 @@ class OpsTraceManager:
return encrypted_config.model_dump()
@classmethod
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
"""
Decrypt tracing config
:param tenant_id: tenant id
@ -408,7 +408,7 @@ class OpsTraceManager:
return dict(decrypted_config)
@classmethod
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict[str, Any]):
"""
Decrypt tracing config
:param tracing_provider: tracing provider
@ -581,7 +581,7 @@ class OpsTraceManager:
return app_trace_config
@staticmethod
def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
def check_trace_config_is_effective(tracing_config: dict[str, Any], tracing_provider: str):
"""
Check trace config is effective
:param tracing_config: tracing config
@ -596,7 +596,7 @@ class OpsTraceManager:
return trace_instance(config).api_check()
@staticmethod
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
def get_trace_config_project_key(tracing_config: dict[str, Any], tracing_provider: str):
"""
get trace config is project key
:param tracing_config: tracing config
@ -611,7 +611,7 @@ class OpsTraceManager:
return trace_instance(config).get_project_key()
@staticmethod
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
def get_trace_config_project_url(tracing_config: dict[str, Any], tracing_provider: str):
"""
get trace config is project key
:param tracing_config: tracing config
@ -1322,8 +1322,8 @@ class TraceTask:
error=error,
)
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
node_data: dict = kwargs.get("node_execution_data", {})
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict[str, Any]:
node_data: dict[str, Any] = kwargs.get("node_execution_data", {})
if not node_data:
return {}
@ -1431,7 +1431,7 @@ class TraceTask:
return node_trace
return DraftNodeExecutionTrace(**node_trace.model_dump())
def _extract_streaming_metrics(self, message_data) -> dict:
def _extract_streaming_metrics(self, message_data) -> dict[str, Any]:
if not message_data.message_metadata:
return {}

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field, model_validator
@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity):
entity of an endpoint
"""
settings: dict
settings: dict[str, Any]
tenant_id: str
plugin_id: str
expired_at: datetime

View File

@ -1,3 +1,5 @@
from typing import Any
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from pydantic import BaseModel, Field, computed_field, model_validator
@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel):
@model_validator(mode="before")
@classmethod
def transform_declaration(cls, data: dict):
def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]:
if "endpoint" in data and not data["endpoint"]:
del data["endpoint"]
if "model" in data and not data["model"]:

View File

@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_category(cls, values: dict):
def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]:
# auto detect category
if values.get("tool"):
values["category"] = PluginCategory.Tool

View File

@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel):
"""
result: bool
credentials: dict | None = None
credentials: dict[str, Any] | None = None
class PluginModelSchemaEntity(BaseModel):

View File

@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel):
tool_type: Literal["builtin", "workflow", "api", "mcp"]
provider: str
tool: str
tool_parameters: dict
tool_parameters: dict[str, Any]
credential_id: str | None = None
@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel):
opt: Literal["encrypt", "decrypt", "clear"]
namespace: Literal["endpoint"]
identity: str
data: dict = Field(default_factory=dict)
data: dict[str, Any] = Field(default_factory=dict)
config: list[BasicProviderConfig] = Field(default_factory=list)

View File

@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient):
tool_provider_id = DatasourceProviderID(provider_id)
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
data = json_response.get("data")
if data:
for datasource in data.get("declaration", {}).get("datasources", []):

View File

@ -1,3 +1,5 @@
from typing import Any
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
def create_endpoint(
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
self,
tenant_id: str,
user_id: str,
plugin_unique_identifier: str,
name: str,
settings: dict[str, Any],
) -> bool:
"""
Create an endpoint for the given plugin.
@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient):
params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
)
def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
def update_endpoint(
self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]
) -> bool:
"""
Update the settings of the given endpoint.
"""

View File

@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> AIModelEntity | None:
"""
Get model schema
@ -80,7 +80,7 @@ class PluginModelClient(BasePluginClient):
return None
def validate_provider_credentials(
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the provider
@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> bool:
"""
validate the credentials of the provider
@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
input_type: str,
) -> EmbeddingResult:
@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
"""
@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None = None,
@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Generator[bytes, None, None]:
@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
language: str | None = None,
):
"""
@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
"""
@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
text: str,
) -> bool:
"""

View File

@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any
from requests import HTTPError
@ -263,7 +264,7 @@ class PluginInstaller(BasePluginClient):
original_plugin_unique_identifier: str,
new_plugin_unique_identifier: str,
source: PluginInstallationSource,
meta: dict,
meta: dict[str, Any],
) -> PluginInstallTaskStartResponse:
"""
Upgrade a plugin.

View File

@ -1,6 +1,7 @@
"""Abstract interface for document loader implementations."""
import csv
from typing import Any
import pandas as pd
@ -23,7 +24,7 @@ class CSVExtractor(BaseExtractor):
encoding: str | None = None,
autodetect_encoding: bool = False,
source_column: str | None = None,
csv_args: dict | None = None,
csv_args: dict[str, Any] | None = None,
):
"""Initialize with file path."""
self._file_path = file_path

View File

@ -54,8 +54,8 @@ class BaseAPIClient:
self,
method: str,
endpoint: str,
query_params: dict | None = None,
data: dict | None = None,
query_params: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
**kwargs,
) -> Response:
stream = kwargs.pop("stream", False)
@ -66,19 +66,25 @@ class BaseAPIClient:
return self.session.request(method, url, params=query_params, json=data, **kwargs)
def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
def _get(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
return self._request("GET", endpoint, query_params=query_params, **kwargs)
def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _post(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs)
def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _put(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs)
def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
def _delete(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
return self._request("DELETE", endpoint, query_params=query_params, **kwargs)
def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _patch(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs)
@ -99,7 +105,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
finally:
response.close()
def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
def process_response(self, response: Response) -> dict[str, Any] | bytes | list[Any] | None | Generator:
if response.status_code == 401:
raise WaterCrawlAuthenticationError(response)
@ -186,7 +192,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
yield from generator
def get_crawl_request_results(
self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None
self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict[str, Any] | None = None
):
query_params = query_params or {}
query_params.update({"page": page or 1, "page_size": page_size or 25})
@ -210,7 +216,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
if event_data["type"] == "result":
return event_data["data"]
def download_result(self, result_object: dict):
def download_result(self, result_object: dict[str, Any]):
response = httpx.get(result_object["result"], timeout=None)
try:
response.raise_for_status()

View File

@ -120,7 +120,7 @@ class WaterCrawlProvider:
}
def _get_results(
self, crawl_request_id: str, query_params: dict | None = None
self, crawl_request_id: str, query_params: dict[str, Any] | None = None
) -> Generator[WatercrawlDocumentData, None, None]:
page = 0
page_size = 100

View File

@ -875,7 +875,11 @@ class DatasetRetrieval:
return retrieval_resource_list
def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
self,
flask_app: Flask,
documents: list[Document],
message_id: str | None = None,
timer: dict[str, Any] | None = None,
):
"""Handle retrieval end."""
with flask_app.app_context():
@ -980,7 +984,7 @@ class DatasetRetrieval:
self._send_trace_task(message_id, documents, timer)
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None):
"""Send trace task if trace manager is available."""
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
@ -1142,7 +1146,7 @@ class DatasetRetrieval:
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
inputs: dict[str, Any],
) -> list[DatasetRetrieverBaseTool] | None:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
@ -1337,7 +1341,7 @@ class DatasetRetrieval:
metadata_filtering_mode: str,
metadata_model_config: ModelConfig,
metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict,
inputs: dict[str, Any],
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
@ -1417,7 +1421,7 @@ class DatasetRetrieval:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str:
if not inputs:
return text

View File

@ -4,6 +4,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Any
from uuid import UUID
import numpy as np
@ -50,7 +51,7 @@ def safe_json_value(v):
return v
def safe_json_dict(d: dict):
def safe_json_dict(d: dict[str, Any]):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
@ -196,11 +197,11 @@ class ToolFileMessageTransformer:
@staticmethod
def _with_tool_file_meta(
meta: dict | None,
meta: dict[str, Any] | None,
*,
tool_file_id: str | None = None,
url: str | None = None,
) -> dict:
) -> dict[str, Any]:
normalized_meta = meta.copy() if meta is not None else {}
resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url)
if resolved_tool_file_id and "tool_file_id" not in normalized_meta:

View File

@ -32,7 +32,7 @@ class OpenAPISpecDict(TypedDict):
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
openapi: Mapping[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -236,7 +236,7 @@ class ApiBasedToolSchemaParser:
return value
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
def _get_tool_parameter_type(parameter: dict[str, Any]) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: str | None = None
if parameter.get("format") == "binary":
@ -265,7 +265,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
yaml: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
@ -278,14 +278,14 @@ class ApiBasedToolSchemaParser:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
openapi: dict[str, Any] = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
swagger: dict[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> OpenAPISpecDict:
warning = warning or {}
"""
@ -351,7 +351,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
json: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
@ -392,7 +392,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
content: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle

View File

@ -1,4 +1,5 @@
import copy
from typing import Any, TypedDict
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
@ -15,9 +16,18 @@ from core.prompt.prompt_templates.advanced_prompt_templates import (
from models.model import AppMode
class AdvancedPromptTemplateArgs(TypedDict):
"""Expected shape of the args dict passed to AdvancedPromptTemplateService.get_prompt."""
app_mode: str
model_mode: str
model_name: str
has_context: str
class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict):
def get_prompt(cls, args: AdvancedPromptTemplateArgs) -> dict[str, Any]:
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
@ -29,7 +39,7 @@ class AdvancedPromptTemplateService:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@classmethod
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]:
context_prompt = copy.deepcopy(CONTEXT)
match app_mode:
@ -63,7 +73,7 @@ class AdvancedPromptTemplateService:
return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
def get_completion_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]:
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
@ -72,7 +82,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
def get_chat_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]:
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
@ -81,7 +91,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
match app_mode:

View File

@ -233,7 +233,7 @@ class DatasetService:
embedding_model_provider: str | None = None,
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
summary_index_setting: dict | None = None,
summary_index_setting: dict[str, Any] | None = None,
):
# check if dataset name already exists
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
@ -2493,7 +2493,7 @@ class DocumentService:
data_source_type: str,
document_form: str,
document_language: str,
data_source_info: dict,
data_source_info: dict[str, Any],
created_from: str,
position: int,
account: Account,
@ -2850,7 +2850,7 @@ class DocumentService:
raise ValueError("Process rule segmentation max_tokens is invalid")
@classmethod
def estimate_args_validate(cls, args: dict):
def estimate_args_validate(cls, args: dict[str, Any]):
if "info_list" not in args or not args["info_list"]:
raise ValueError("Data source info is required")
@ -3132,7 +3132,7 @@ class DocumentService:
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
def segment_create_args_validate(cls, args: dict[str, Any], document: Document):
if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
@ -3149,7 +3149,7 @@ class SegmentService:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
def create_segment(cls, args: dict[str, Any], document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None

View File

@ -1,9 +1,25 @@
from enum import StrEnum, auto
from pydantic import BaseModel, Field, field_validator
from libs.helper import EmailStr
from libs.password import valid_password
class LoginFailureReason(StrEnum):
"""Bounded reason codes for failed login audit logs."""
ACCOUNT_BANNED = auto()
ACCOUNT_IN_FREEZE = auto()
ACCOUNT_NOT_FOUND = auto()
EMAIL_CODE_EMAIL_MISMATCH = auto()
INVALID_CREDENTIALS = auto()
INVALID_EMAIL_CODE = auto()
INVALID_EMAIL_CODE_TOKEN = auto()
INVALID_INVITATION_EMAIL = auto()
LOGIN_RATE_LIMITED = auto()
class LoginPayloadBase(BaseModel):
email: EmailStr
password: str

View File

@ -47,7 +47,7 @@ class ExternalDatasetService:
return external_knowledge_apis.items, external_knowledge_apis.total
@classmethod
def validate_api_list(cls, api_settings: dict):
def validate_api_list(cls, api_settings: dict[str, Any]):
if not api_settings:
raise ValueError("api list is empty")
if not api_settings.get("endpoint"):
@ -56,7 +56,7 @@ class ExternalDatasetService:
raise ValueError("api_key is required")
@staticmethod
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict[str, Any]) -> ExternalKnowledgeApis:
settings = args.get("settings")
if settings is None:
raise ValueError("settings is required")
@ -75,7 +75,7 @@ class ExternalDatasetService:
return external_knowledge_api
@staticmethod
def check_endpoint_and_api_key(settings: dict):
def check_endpoint_and_api_key(settings: dict[str, Any]):
if "endpoint" not in settings or not settings["endpoint"]:
raise ValueError("endpoint is required")
if "api_key" not in settings or not settings["api_key"]:
@ -178,7 +178,9 @@ class ExternalDatasetService:
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
def document_create_args_validate(
tenant_id: str, external_knowledge_api_id: str, process_parameter: dict[str, Any]
):
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
@ -222,7 +224,7 @@ class ExternalDatasetService:
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
def assembling_headers(authorization: Authorization, headers: dict[str, Any] | None = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
@ -248,11 +250,11 @@ class ExternalDatasetService:
return headers
@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
def get_external_knowledge_api_settings(settings: dict[str, Any]) -> ExternalKnowledgeApiSetting:
return ExternalKnowledgeApiSetting.model_validate(settings)
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
def create_external_dataset(tenant_id: str, user_id: str, args: dict[str, Any]) -> Dataset:
# check if dataset name already exists
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
@ -304,7 +306,7 @@ class ExternalDatasetService:
tenant_id: str,
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
external_retrieval_parameters: dict[str, Any],
metadata_condition: MetadataFilteringCondition | None = None,
):
external_knowledge_binding = db.session.scalar(

View File

@ -44,7 +44,7 @@ class HitTestingService:
dataset: Dataset,
query: str,
account: Account,
retrieval_model: dict | None,
retrieval_model: dict[str, Any] | None,
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,

View File

@ -1,8 +1,22 @@
import os
from typing import TypedDict
import httpx
class UtmInfo(TypedDict, total=False):
"""Expected shape of the utm_info dict passed to record_utm.
All fields are optional; missing keys default to an empty string.
"""
utm_source: str
utm_medium: str
utm_campaign: str
utm_content: str
utm_term: str
class OperationService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@ -17,7 +31,7 @@ class OperationService:
return response.json()
@classmethod
def record_utm(cls, tenant_id: str, utm_info: dict):
def record_utm(cls, tenant_id: str, utm_info: UtmInfo):
params = {
"tenant_id": tenant_id,
"utm_source": utm_info.get("utm_source", ""),

View File

@ -1,3 +1,5 @@
from typing import Any
from sqlalchemy import select
from core.ops.entities.config_entity import BaseTracingConfig
@ -135,7 +137,7 @@ class OpsService:
return trace_config_data.to_dict()
@classmethod
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
"""
Create tracing app config
:param app_id: app id
@ -210,7 +212,7 @@ class OpsService:
return {"result": "success"}
@classmethod
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
"""
Update tracing app config
:param app_id: app id

View File

@ -1,9 +1,13 @@
from typing import Any
from core.plugin.impl.endpoint import PluginEndpointClient
class EndpointService:
@classmethod
def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict):
def create_endpoint(
cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict[str, Any]
):
return PluginEndpointClient().create_endpoint(
tenant_id=tenant_id,
user_id=user_id,
@ -32,7 +36,7 @@ class EndpointService:
)
@classmethod
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]):
return PluginEndpointClient().update_endpoint(
tenant_id=tenant_id,
user_id=user_id,

View File

@ -1,6 +1,7 @@
import json
from os import path
from pathlib import Path
from typing import Any
from flask import current_app
@ -13,21 +14,21 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json
"""
builtin_data: dict | None = None
builtin_data: dict[str, Any] | None = None
def get_type(self) -> str:
return PipelineTemplateType.BUILTIN
def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
result = self.fetch_pipeline_templates_from_builtin(language)
return result
def get_pipeline_template_detail(self, template_id: str):
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_builtin(template_id)
return result
@classmethod
def _get_builtin_data(cls) -> dict:
def _get_builtin_data(cls) -> dict[str, Any]:
"""
Get builtin data.
:return:
@ -43,7 +44,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return cls.builtin_data or {}
@classmethod
def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict:
def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict[str, Any]:
"""
Fetch pipeline templates from builtin.
:param language: language
@ -53,7 +54,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None:
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict[str, Any] | None:
"""
Fetch pipeline template detail from builtin.
:param template_id: Template ID

View File

@ -1,3 +1,5 @@
from typing import Any
import yaml
from sqlalchemy import select
@ -13,12 +15,12 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval recommended app from database
"""
def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
_, current_tenant_id = current_account_with_tenant()
result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
return result
def get_pipeline_template_detail(self, template_id: str):
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
@ -26,7 +28,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.CUSTOMIZED
@classmethod
def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict:
def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict[str, Any]:
"""
Fetch pipeline templates from db.
:param tenant_id: tenant id
@ -53,7 +55,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict[str, Any] | None:
"""
Fetch pipeline template detail from db.
:param template_id: Template ID

View File

@ -1,3 +1,5 @@
from typing import Any
import yaml
from sqlalchemy import select
@ -12,11 +14,11 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval pipeline template from database
"""
def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
result = self.fetch_pipeline_templates_from_db(language)
return result
def get_pipeline_template_detail(self, template_id: str):
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
@ -24,7 +26,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.DATABASE
@classmethod
def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
def fetch_pipeline_templates_from_db(cls, language: str) -> dict[str, Any]:
"""
Fetch pipeline templates from db.
:param language: language
@ -54,7 +56,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict[str, Any] | None:
"""
Fetch pipeline template detail from db.
:param pipeline_id: Pipeline ID

View File

@ -1,15 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any
class PipelineTemplateRetrievalBase(ABC):
"""Interface for pipeline template retrieval."""
@abstractmethod
def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
raise NotImplementedError
@abstractmethod
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
raise NotImplementedError
@abstractmethod

View File

@ -1,4 +1,5 @@
import logging
from typing import Any
import httpx
@ -15,8 +16,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval recommended app from dify official
"""
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
result: dict | None
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result: dict[str, Any] | None
try:
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
@ -24,7 +25,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
return result
def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
try:
result = self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:
@ -36,7 +37,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.REMOTE
@classmethod
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict:
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict[str, Any]:
"""
Fetch pipeline template detail from dify official.
@ -53,11 +54,11 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
+ f" status_code: {response.status_code},"
+ f" response: {response.text[:1000]}"
)
data: dict = response.json()
data: dict[str, Any] = response.json()
return data
@classmethod
def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict:
def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict[str, Any]:
"""
Fetch pipeline templates from dify official.
:param language: language
@ -69,6 +70,6 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if response.status_code != 200:
raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}")
result: dict = response.json()
result: dict[str, Any] = response.json()
return result

View File

@ -92,7 +92,7 @@ class ApiToolManageService:
@staticmethod
def convert_schema_to_tool_bundles(
schema: str, extra_info: dict | None = None
schema: str, extra_info: dict[str, Any] | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
convert schema to tool bundles
@ -109,8 +109,8 @@ class ApiToolManageService:
user_id: str,
tenant_id: str,
provider_name: str,
icon: dict,
credentials: dict,
icon: dict[str, Any],
credentials: dict[str, Any],
schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
@ -244,8 +244,8 @@ class ApiToolManageService:
tenant_id: str,
provider_name: str,
original_provider: str,
icon: dict,
credentials: dict,
icon: dict[str, Any],
credentials: dict[str, Any],
_schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str | None,
@ -356,8 +356,8 @@ class ApiToolManageService:
tenant_id: str,
provider_name: str,
tool_name: str,
credentials: dict,
parameters: dict,
credentials: dict[str, Any],
parameters: dict[str, Any],
schema_type: ApiProviderSchemaType,
schema: str,
):

View File

@ -147,7 +147,7 @@ class BuiltinToolManageService:
tenant_id: str,
provider: str,
credential_id: str,
credentials: dict | None = None,
credentials: dict[str, Any] | None = None,
name: str | None = None,
):
"""
@ -177,7 +177,7 @@ class BuiltinToolManageService:
)
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials: dict = {
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
@ -216,7 +216,7 @@ class BuiltinToolManageService:
api_type: CredentialType,
tenant_id: str,
provider: str,
credentials: dict,
credentials: dict[str, Any],
expires_at: int = -1,
name: str | None = None,
):
@ -657,7 +657,7 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: dict | None = None,
client_params: dict[str, Any] | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""

View File

@ -69,7 +69,9 @@ class ToolTransformService:
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
def repack_provider(
tenant_id: str, provider: dict[str, Any] | ToolProviderApiEntity | PluginDatasourceProviderEntity
):
"""
repack provider

View File

@ -1,6 +1,7 @@
import json
import logging
from datetime import datetime
from typing import Any
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import delete, or_, select
@ -35,7 +36,7 @@ class WorkflowToolManageService:
workflow_app_id: str,
name: str,
label: str,
icon: dict,
icon: dict[str, Any],
description: str,
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
@ -117,7 +118,7 @@ class WorkflowToolManageService:
workflow_tool_id: str,
name: str,
label: str,
icon: dict,
icon: dict[str, Any],
description: str,
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",

View File

@ -91,7 +91,7 @@ class WebsiteCrawlApiRequest:
return CrawlRequest(url=self.url, provider=self.provider, options=options)
@classmethod
def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
def from_args(cls, args: dict[str, Any]) -> WebsiteCrawlApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
url = args.get("url")
@ -115,7 +115,7 @@ class WebsiteCrawlStatusApiRequest:
job_id: str
@classmethod
def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
def from_args(cls, args: dict[str, Any], job_id: str) -> WebsiteCrawlStatusApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
if not provider:
@ -163,7 +163,7 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
def _get_decrypted_api_key(cls, tenant_id: str, config: dict[str, Any]) -> str:
"""Decrypt and return the API key from config."""
api_key = config.get("api_key")
if not api_key:
@ -171,7 +171,7 @@ class WebsiteService:
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
@classmethod
def document_create_args_validate(cls, args: dict):
def document_create_args_validate(cls, args: dict[str, Any]):
"""Validate arguments for document creation."""
try:
WebsiteCrawlApiRequest.from_args(args)
@ -195,7 +195,7 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
params: dict[str, Any]
@ -225,7 +225,7 @@ class WebsiteService:
return {"status": "active", "job_id": job_id}
@classmethod
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
# Convert CrawlOptions back to dict format for WaterCrawlProvider
options = {
"limit": request.options.limit,
@ -290,7 +290,7 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> CrawlStatusDict:
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
crawl_status_data: CrawlStatusDict = {
@ -364,7 +364,9 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
def _get_firecrawl_url_data(
cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
) -> dict[str, Any] | None:
crawl_data: list[FirecrawlDocumentData] | None = None
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
@ -438,7 +440,7 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
params = {"onlyMainContent": request.only_main_content}
return dict(firecrawl_app.scrape_url(url=request.url, params=params))

View File

@ -1,5 +1,6 @@
import threading
from collections.abc import Sequence
from typing import TypedDict
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
@ -19,6 +20,14 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class WorkflowRunListArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to workflow run pagination methods."""
limit: int
last_id: str
status: str
class WorkflowRunService:
_session_factory: sessionmaker
_workflow_run_repo: APIWorkflowRunRepository
@ -37,7 +46,10 @@ class WorkflowRunService:
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
def get_paginate_advanced_chat_workflow_runs(
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
self,
app_model: App,
args: WorkflowRunListArgs,
triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
) -> InfiniteScrollPagination:
"""
Get advanced chat app workflow run list
@ -73,7 +85,10 @@ class WorkflowRunService:
return pagination
def get_paginate_workflow_runs(
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
self,
app_model: App,
args: WorkflowRunListArgs,
triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
) -> InfiniteScrollPagination:
"""
Get workflow run list

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import pytest
from faker import Faker
from sqlalchemy import delete
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
@ -28,12 +29,12 @@ class TestCreateSegmentToIndexTask:
"""Clean up database and Redis before each test to ensure isolation."""
# Clear all test data using fixture session
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache

View File

@ -14,6 +14,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -41,9 +42,9 @@ class TestSendEmailCodeLoginMailTask:
from extensions.ext_redis import redis_client
# Clear all test data
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache

View File

@ -6,6 +6,7 @@ import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy import delete
from configs import dify_config
from core.app.app_config.entities import WorkflowUIBasedAppConfig
@ -30,14 +31,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
db_session_with_containers.query(HumanInputFormRecipient).delete()
db_session_with_containers.query(HumanInputDelivery).delete()
db_session_with_containers.query(HumanInputForm).delete()
db_session_with_containers.query(WorkflowPause).delete()
db_session_with_containers.query(WorkflowRun).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(HumanInputFormRecipient))
db_session_with_containers.execute(delete(HumanInputDelivery))
db_session_with_containers.execute(delete(HumanInputForm))
db_session_with_containers.execute(delete(WorkflowPause))
db_session_with_containers.execute(delete(WorkflowRun))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()

View File

@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete, select
from extensions.ext_redis import redis_client
from libs.email_i18n import EmailType
@ -44,9 +45,9 @@ class TestMailInviteMemberTask:
def cleanup_database(self, db_session_with_containers):
"""Clean up database before each test to ensure isolation."""
# Clear all test data
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache
@ -491,10 +492,10 @@ class TestMailInviteMemberTask:
assert tenant.name is not None
# Verify tenant relationship exists
tenant_join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=pending_account.id)
.first()
tenant_join = db_session_with_containers.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == pending_account.id)
.limit(1)
)
assert tenant_join is not None
assert tenant_join.role == TenantAccountRole.NORMAL

View File

@ -4,6 +4,7 @@ from unittest.mock import ANY, call, patch
import pytest
from graphon.variables.segments import StringSegment
from graphon.variables.types import SegmentType
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from extensions.storage.storage_type import StorageType
@ -20,11 +21,11 @@ from tasks.remove_app_and_related_data_task import (
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
db_session_with_containers.query(WorkflowDraftVariable).delete()
db_session_with_containers.query(WorkflowDraftVariableFile).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.execute(delete(WorkflowDraftVariable))
db_session_with_containers.execute(delete(WorkflowDraftVariableFile))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(App))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.commit()
@ -127,21 +128,21 @@ class TestDeleteDraftVariablesBatch:
result = delete_draft_variables_batch(app1.id, batch_size=100)
assert result == 150
app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app1.id
app1_remaining_count = db_session_with_containers.scalar(
select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app1.id)
)
app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app2.id
app2_remaining_count = db_session_with_containers.scalar(
select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app2.id)
)
assert app1_remaining.count() == 0
assert app2_remaining.count() == 100
assert app1_remaining_count == 0
assert app2_remaining_count == 100
def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers):
"""Test deletion when no draft variables exist for the app."""
result = delete_draft_variables_batch(str(uuid.uuid4()), 1000)
assert result == 0
assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0
assert db_session_with_containers.scalar(select(func.count()).select_from(WorkflowDraftVariable)) == 0
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.logger")
@ -190,12 +191,16 @@ class TestDeleteDraftVariableOffloadData:
expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys]
mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
WorkflowDraftVariableFile.id.in_(file_ids)
remaining_var_files_count = db_session_with_containers.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(file_ids))
)
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
assert remaining_var_files.count() == 0
assert remaining_upload_files.count() == 0
remaining_upload_files_count = db_session_with_containers.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
assert remaining_var_files_count == 0
assert remaining_upload_files_count == 0
@patch("extensions.ext_storage.storage")
@patch("tasks.remove_app_and_related_data_task.logging")
@ -217,9 +222,13 @@ class TestDeleteDraftVariableOffloadData:
assert result == 1
mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0])
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
WorkflowDraftVariableFile.id.in_(file_ids)
remaining_var_files_count = db_session_with_containers.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(file_ids))
)
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
assert remaining_var_files.count() == 0
assert remaining_upload_files.count() == 0
remaining_upload_files_count = db_session_with_containers.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
assert remaining_var_files_count == 0
assert remaining_upload_files_count == 0

View File

@ -11,6 +11,7 @@ from collections.abc import Generator
from typing import Any
import pytest
from sqlalchemy import delete
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -40,9 +41,9 @@ def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[T
yield tenant, account
# Cleanup
db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db_session_with_containers.query(Account).filter_by(id=account.id).delete()
db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete()
db_session_with_containers.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == tenant.id))
db_session_with_containers.execute(delete(Account).where(Account.id == account.id))
db_session_with_containers.execute(delete(Tenant).where(Tenant.id == tenant.id))
db_session_with_containers.commit()
@ -93,14 +94,14 @@ def app_model(
)
from models.workflow import Workflow
db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete()
db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete()
db_session_with_containers.query(App).filter_by(id=app.id).delete()
db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app.id))
db_session_with_containers.execute(delete(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app.id))
db_session_with_containers.execute(delete(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id))
db_session_with_containers.execute(delete(WorkflowPluginTrigger).where(WorkflowPluginTrigger.app_id == app.id))
db_session_with_containers.execute(delete(AppTrigger).where(AppTrigger.app_id == app.id))
db_session_with_containers.execute(delete(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant.id))
db_session_with_containers.execute(delete(Workflow).where(Workflow.app_id == app.id))
db_session_with_containers.execute(delete(App).where(App.id == app.id))
db_session_with_containers.commit()

View File

@ -11,6 +11,7 @@ import pytest
from flask import Flask, Response
from flask.testing import FlaskClient
from graphon.enums import BuiltinNodeTypes
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@ -227,7 +228,9 @@ def test_webhook_trigger_creates_trigger_log(
assert response.status_code == 200
db_session_with_containers.expire_all()
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
logs = db_session_with_containers.scalars(
select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
).all()
assert logs, "Webhook trigger should create trigger log"
@ -611,7 +614,9 @@ def test_schedule_trigger_creates_trigger_log(
# Verify WorkflowTriggerLog was created
db_session_with_containers.expire_all()
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
logs = db_session_with_containers.scalars(
select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
).all()
assert logs, "Schedule trigger should create WorkflowTriggerLog"
assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE
assert logs[0].root_node_id == schedule_node_id
@ -786,11 +791,12 @@ def test_plugin_trigger_full_chain_with_db_verification(
# Verify database records exist
db_session_with_containers.expire_all()
plugin_triggers = (
db_session_with_containers.query(WorkflowPluginTrigger)
.filter_by(app_id=app_model.id, node_id=plugin_node_id)
.all()
)
plugin_triggers = db_session_with_containers.scalars(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.app_id == app_model.id,
WorkflowPluginTrigger.node_id == plugin_node_id,
)
).all()
assert plugin_triggers, "WorkflowPluginTrigger record should exist"
assert plugin_triggers[0].provider_id == provider_id
assert plugin_triggers[0].event_name == "test_event"

View File

@ -14,18 +14,20 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from werkzeug.exceptions import Unauthorized
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailPasswordLoginLimitError,
InvalidEmailError,
)
from controllers.console.auth.login import LoginApi, LogoutApi
from controllers.console.auth.login import EmailCodeLoginApi, LoginApi, LogoutApi
from controllers.console.error import (
AccountBannedError,
AccountInFreezeError,
WorkspacesLimitExceeded,
)
from services.entities.auth_entities import LoginFailureReason
from services.errors.account import AccountLoginError, AccountPasswordError
@ -34,6 +36,11 @@ def encode_password(password: str) -> str:
return base64.b64encode(password.encode("utf-8")).decode()
def encode_code(code: str) -> str:
"""Helper to encode verification code as Base64 for testing."""
return base64.b64encode(code.encode("utf-8")).decode()
class TestLoginApi:
"""Test cases for the LoginApi endpoint."""
@ -197,12 +204,17 @@ class TestLoginApi:
mock_get_invitation.return_value = None
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(EmailPasswordLoginLimitError):
login_api.post()
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(EmailPasswordLoginLimitError):
login_api.post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "test@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.LOGIN_RATE_LIMITED
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True)
@ -220,12 +232,17 @@ class TestLoginApi:
mock_is_frozen.return_value = True
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(AccountInFreezeError):
login_api.post()
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(AccountInFreezeError):
login_api.post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "frozen@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_IN_FREEZE
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@ -257,14 +274,20 @@ class TestLoginApi:
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")}
):
login_api = LoginApi()
with pytest.raises(AuthenticationFailedError):
login_api.post()
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/login",
method="POST",
json={"email": "test@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
with pytest.raises(AuthenticationFailedError):
login_api.post()
mock_add_rate_limit.assert_called_once_with("test@example.com")
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "test@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@ -288,12 +311,19 @@ class TestLoginApi:
mock_authenticate.side_effect = AccountLoginError("Account is banned")
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}
):
login_api = LoginApi()
with pytest.raises(AccountBannedError):
login_api.post()
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/login",
method="POST",
json={"email": "banned@example.com", "password": encode_password("ValidPass123!")},
):
login_api = LoginApi()
with pytest.raises(AccountBannedError):
login_api.post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "banned@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@ -417,6 +447,36 @@ class TestLoginApi:
mock_add_rate_limit.assert_not_called()
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
@patch("controllers.console.auth.login._get_account_with_case_fallback")
def test_email_code_login_logs_banned_account(
self,
mock_get_account,
mock_revoke_token,
mock_get_token_data,
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
):
with pytest.raises(AccountBannedError):
EmailCodeLoginApi().post()
mock_revoke_token.assert_called_once_with("token-123")
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "user@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
class TestLogoutApi:
"""Test cases for the LogoutApi endpoint."""

View File

@ -4,9 +4,12 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from jwt import InvalidTokenError
from werkzeug.exceptions import Unauthorized
import services.errors.account
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi
from services.entities.auth_entities import LoginFailureReason
def encode_code(code: str) -> str:
@ -115,13 +118,18 @@ class TestLoginApi:
def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None:
from controllers.console.error import AccountBannedError
with app.test_request_context(
"/web/login",
method="POST",
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
):
with pytest.raises(AccountBannedError):
LoginApi().post()
with patch("controllers.web.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/web/login",
method="POST",
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
):
with pytest.raises(AccountBannedError):
LoginApi().post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "user@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
@patch(
"controllers.web.login.WebAppAuthService.authenticate",
@ -130,13 +138,87 @@ class TestLoginApi:
def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None:
from controllers.console.auth.error import AuthenticationFailedError
with app.test_request_context(
"/web/login",
method="POST",
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
):
with pytest.raises(AuthenticationFailedError):
LoginApi().post()
with patch("controllers.web.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/web/login",
method="POST",
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
):
with pytest.raises(AuthenticationFailedError):
LoginApi().post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "user@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS
@patch(
"controllers.web.login.WebAppAuthService.authenticate",
side_effect=services.errors.account.AccountNotFoundError(),
)
def test_login_account_not_found(self, mock_auth: MagicMock, app: Flask) -> None:
from controllers.console.auth.error import AuthenticationFailedError
with patch("controllers.web.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/web/login",
method="POST",
json={"email": "missing@example.com", "password": base64.b64encode(b"Valid1234").decode()},
):
with pytest.raises(AuthenticationFailedError):
LoginApi().post()
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "missing@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_NOT_FOUND
@patch("controllers.web.login.WebAppAuthService.get_email_code_login_data", return_value=None)
def test_email_code_login_logs_invalid_token(self, mock_get_token_data: MagicMock, app: Flask) -> None:
with patch("controllers.web.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/web/email-code-login/validity",
method="POST",
json={"email": "user@example.com", "code": encode_code("123456"), "token": "token-123"},
):
with pytest.raises(InvalidTokenError):
EmailCodeLoginApi().post()
mock_get_token_data.assert_called_once_with("token-123")
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "user@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_EMAIL_CODE_TOKEN
@patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token")
@patch(
"controllers.web.login.WebAppAuthService.get_user_through_email",
side_effect=Unauthorized("Account is banned."),
)
@patch(
"controllers.web.login.WebAppAuthService.get_email_code_login_data",
return_value={"email": "User@Example.com", "code": "123456"},
)
def test_email_code_login_logs_banned_account(
self,
mock_get_token_data: MagicMock,
mock_get_user: MagicMock,
mock_revoke_token: MagicMock,
app: Flask,
) -> None:
from controllers.console.error import AccountBannedError
with patch("controllers.web.login.logger.warning") as mock_log_warning:
with app.test_request_context(
"/web/email-code-login/validity",
method="POST",
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
):
with pytest.raises(AccountBannedError):
EmailCodeLoginApi().post()
mock_get_token_data.assert_called_once_with("token-123")
mock_revoke_token.assert_called_once_with("token-123")
assert mock_log_warning.call_count == 1
assert mock_log_warning.call_args.args[1] == "user@example.com"
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
class TestLoginStatusApi:

63
pnpm-lock.yaml generated
View File

@ -16,14 +16,17 @@ catalogs:
specifier: 8.1.1
version: 8.1.1
'@base-ui/react':
specifier: 1.3.0
version: 1.3.0
specifier: 1.4.0
version: 1.4.0
'@chromatic-com/storybook':
specifier: 5.1.1
version: 5.1.1
'@cucumber/cucumber':
specifier: 12.7.0
version: 12.7.0
'@date-fns/tz':
specifier: 1.2.0
version: 1.2.0
'@egoist/tailwindcss-icons':
specifier: 1.9.2
version: 1.9.2
@ -267,6 +270,9 @@ catalogs:
cron-parser:
specifier: 5.5.0
version: 5.5.0
date-fns:
specifier: 4.0.0
version: 4.0.0
dayjs:
specifier: 1.11.20
version: 1.11.20
@ -433,8 +439,8 @@ catalogs:
specifier: 5.2.4
version: 5.2.4
react-i18next:
specifier: 17.0.2
version: 17.0.2
specifier: 16.5.8
version: 16.5.8
react-multi-email:
specifier: 1.0.25
version: 1.0.25
@ -648,7 +654,10 @@ importers:
version: 1.27.6(@amplitude/rrweb@2.0.0-alpha.37)(rollup@4.59.0)
'@base-ui/react':
specifier: 'catalog:'
version: 1.3.0(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
version: 1.4.0(@date-fns/tz@1.2.0)(@types/react@19.2.14)(date-fns@4.0.0)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
'@date-fns/tz':
specifier: 'catalog:'
version: 1.2.0
'@emoji-mart/data':
specifier: 'catalog:'
version: 1.2.1
@ -751,6 +760,9 @@ importers:
cron-parser:
specifier: 'catalog:'
version: 5.5.0
date-fns:
specifier: 'catalog:'
version: 4.0.0
dayjs:
specifier: 'catalog:'
version: 1.11.20
@ -876,7 +888,7 @@ importers:
version: 5.2.4(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
react-i18next:
specifier: 'catalog:'
version: 17.0.2(i18next@26.0.4(typescript@6.0.2))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(typescript@6.0.2)
version: 16.5.8(i18next@26.0.4(typescript@6.0.2))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(typescript@6.0.2)
react-multi-email:
specifier: 'catalog:'
version: 1.0.25(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
@ -1391,19 +1403,21 @@ packages:
resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==}
engines: {node: '>=6.9.0'}
'@base-ui/react@1.3.0':
resolution: {integrity: sha512-FwpKqZbPz14AITp1CVgf4AjhKPe1OeeVKSBMdgD10zbFlj3QSWelmtCMLi2+/PFZZcIm3l87G7rwtCZJwHyXWA==}
'@base-ui/react@1.4.0':
resolution: {integrity: sha512-QcqdVbr/+ba2/RAKJIV1PV6S02Q5+r6a4Eym8ndBw+ZbBILkkmQAyRxXCg/pArrHnkrGeU8goe26aw0h6eE8pg==}
engines: {node: '>=14.0.0'}
peerDependencies:
'@date-fns/tz': ^1.2.0
'@types/react': ^17 || ^18 || ^19
date-fns: ^4.0.0
react: ^17 || ^18 || ^19
react-dom: ^17 || ^18 || ^19
peerDependenciesMeta:
'@types/react':
optional: true
'@base-ui/utils@0.2.6':
resolution: {integrity: sha512-yQ+qeuqohwhsNpoYDqqXaLllYAkPCP4vYdDrVo8FQXaAPfHWm1pG/Vm+jmGTA5JFS0BAIjookyapuJFY8F9PIw==}
'@base-ui/utils@0.2.7':
resolution: {integrity: sha512-nXYKhiL/0JafyJE8PfcflipGftOftlIwKd72rU15iZ1M5yqgg5J9P8NHU71GReDuXco5MJA/eVQqUT5WRqX9sA==}
peerDependencies:
'@types/react': ^17 || ^18 || ^19
react: ^17 || ^18 || ^19
@ -1532,6 +1546,9 @@ packages:
'@cucumber/tag-expressions@9.1.0':
resolution: {integrity: sha512-bvHjcRFZ+J1TqIa9eFNO1wGHqwx4V9ZKV3hYgkuK/VahHx73uiP4rKV3JVrvWSMrwrFvJG6C8aEwnCWSvbyFdQ==}
'@date-fns/tz@1.2.0':
resolution: {integrity: sha512-LBrd7MiJZ9McsOgxqWX7AaxrDjcFVjWH/tIKJd7pnR7McaslGYOP1QmmiBXdJH/H/yLCT+rcQ7FaPBUxRGUtrg==}
'@e18e/eslint-plugin@0.3.0':
resolution: {integrity: sha512-hHgfpxsrZ2UYHcicA+tGZnmk19uJTaye9VH79O+XS8R4ona2Hx3xjhXghclNW58uXMk3xXlbYEOMr8thsoBmWg==}
peerDependencies:
@ -5325,6 +5342,9 @@ packages:
dagre-d3-es@7.0.14:
resolution: {integrity: sha512-P4rFMVq9ESWqmOgK+dlXvOtLwYg0i7u0HBGJER0LZDJT2VHIPAMZ/riPxqJceWMStH5+E61QxFra9kIS3AqdMg==}
date-fns@4.0.0:
resolution: {integrity: sha512-6K33+I8fQ5otvHgLIvKK1xmMbLAh0pduyrx7dwMXKiGYeoWhmk6M3Zoak9n7bXHMJQlHq1yqmdGy1QxKddJjUA==}
dayjs@1.11.20:
resolution: {integrity: sha512-YbwwqR/uYpeoP4pu043q+LTDLFBLApUP6VxRihdfNTqu4ubqMlGDLd6ErXhEgsyvY0K6nCs7nggYumAN+9uEuQ==}
@ -7376,14 +7396,14 @@ packages:
react: '>=16.8.0'
react-dom: '>=16.8.0'
react-i18next@17.0.2:
resolution: {integrity: sha512-shBftH2vaTWK2Bsp7FiL+cevx3xFJlvFxmsDFQSrJc+6twHkP0tv/bGa01VVWzpreUVVwU+3Hev5iFqRg65RwA==}
react-i18next@16.5.8:
resolution: {integrity: sha512-2ABeHHlakxVY+LSirD+OiERxFL6+zip0PaHo979bgwzeHg27Sqc82xxXWIrSFmfWX0ZkrvXMHwhsi/NGUf5VQg==}
peerDependencies:
i18next: '>= 26.0.1'
i18next: '>= 25.6.2'
react: '>= 16.8.0'
react-dom: '*'
react-native: '*'
typescript: ^5 || ^6
typescript: ^5
peerDependenciesMeta:
react-dom:
optional: true
@ -8896,20 +8916,21 @@ snapshots:
'@babel/helper-string-parser': 7.27.1
'@babel/helper-validator-identifier': 7.28.5
'@base-ui/react@1.3.0(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
'@base-ui/react@1.4.0(@date-fns/tz@1.2.0)(@types/react@19.2.14)(date-fns@4.0.0)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
dependencies:
'@babel/runtime': 7.29.2
'@base-ui/utils': 0.2.6(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
'@base-ui/utils': 0.2.7(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
'@date-fns/tz': 1.2.0
'@floating-ui/react-dom': 2.1.8(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
'@floating-ui/utils': 0.2.11
date-fns: 4.0.0
react: 19.2.5
react-dom: 19.2.5(react@19.2.5)
tabbable: 6.4.0
use-sync-external-store: 1.6.0(react@19.2.5)
optionalDependencies:
'@types/react': 19.2.14
'@base-ui/utils@0.2.6(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
'@base-ui/utils@0.2.7(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
dependencies:
'@babel/runtime': 7.29.2
'@floating-ui/utils': 0.2.11
@ -9128,6 +9149,8 @@ snapshots:
'@cucumber/tag-expressions@9.1.0': {}
'@date-fns/tz@1.2.0': {}
'@e18e/eslint-plugin@0.3.0(eslint@10.2.0(jiti@2.6.1))(oxlint@1.58.0(oxlint-tsgolint@0.20.0))':
dependencies:
eslint-plugin-depend: 1.5.0(eslint@10.2.0(jiti@2.6.1))
@ -12793,6 +12816,8 @@ snapshots:
d3: 7.9.0
lodash-es: 4.18.0
date-fns@4.0.0: {}
dayjs@1.11.20: {}
debug@4.4.3(supports-color@8.1.1):
@ -15458,7 +15483,7 @@ snapshots:
react: 19.2.5
react-dom: 19.2.5(react@19.2.5)
react-i18next@17.0.2(i18next@26.0.4(typescript@6.0.2))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(typescript@6.0.2):
react-i18next@16.5.8(i18next@26.0.4(typescript@6.0.2))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(typescript@6.0.2):
dependencies:
'@babel/runtime': 7.29.2
html-parse-stringify: 3.0.1

View File

@ -48,7 +48,8 @@ catalog:
"@amplitude/analytics-browser": 2.38.1
"@amplitude/plugin-session-replay-browser": 1.27.6
"@antfu/eslint-config": 8.1.1
"@base-ui/react": 1.3.0
"@base-ui/react": 1.4.0
"@date-fns/tz": 1.2.0
"@chromatic-com/storybook": 5.1.1
"@cucumber/cucumber": 12.7.0
"@egoist/tailwindcss-icons": 1.9.2
@ -135,6 +136,7 @@ catalog:
code-inspector-plugin: 1.5.1
copy-to-clipboard: 3.3.3
cron-parser: 5.5.0
date-fns: 4.0.0
dayjs: 1.11.20
decimal.js: 10.6.0
dompurify: 3.3.3
@ -191,7 +193,7 @@ catalog:
react-dom: 19.2.5
react-easy-crop: 5.5.7
react-hotkeys-hook: 5.2.4
react-i18next: 17.0.2
react-i18next: 16.5.8
react-multi-email: 1.0.25
react-papaparse: 4.4.0
react-pdf-highlighter: 8.0.0-rc.0

View File

@ -41,6 +41,33 @@ const renderOpenSelect = ({
}
describe('Select wrappers', () => {
describe('Select root integration', () => {
it('should associate the hidden input with an external form and preserve autocomplete hints', () => {
const formId = 'profile-form'
const { container } = render(
<>
<form id={formId} />
<Select defaultValue="seattle" name="city" form={formId} autoComplete="address-level2">
<SelectTrigger aria-label="city select">
<SelectValue />
</SelectTrigger>
<SelectContent listProps={{ 'role': 'listbox', 'aria-label': 'select list' }}>
<SelectItem value="seattle">Seattle</SelectItem>
<SelectItem value="new-york">New York</SelectItem>
</SelectContent>
</Select>
</>,
)
const hiddenInput = container.querySelector('input[name="city"]')
const form = container.querySelector(`#${formId}`) as HTMLFormElement
expect(hiddenInput).toHaveAttribute('form', formId)
expect(hiddenInput).toHaveAttribute('autocomplete', 'address-level2')
expect(new FormData(form).get('city')).toBe('seattle')
})
})
describe('SelectTrigger', () => {
it('should render clear button when clearable is true and loading is false', () => {
renderOpenSelect({

View File

@ -47,6 +47,23 @@ describe('Slider', () => {
expect(onValueChange).toHaveBeenLastCalledWith(21, expect.anything())
})
it('should round floating point keyboard updates to the configured step', async () => {
const user = userEvent.setup()
const onValueChange = vi.fn()
render(<Slider value={0.2} min={0} max={1} step={0.1} onValueChange={onValueChange} aria-label="Value" />)
const slider = getSliderInput()
await act(async () => {
slider.focus()
await user.keyboard('{ArrowRight}')
})
expect(onValueChange).toHaveBeenCalledTimes(1)
expect(onValueChange).toHaveBeenLastCalledWith(0.3, expect.anything())
})
it('should not trigger onValueChange when disabled', async () => {
const user = userEvent.setup()
const onValueChange = vi.fn()

View File

@ -251,6 +251,32 @@ describe('base/ui/toast', () => {
expect(screen.queryByText('Loading')).not.toBeInTheDocument()
})
// Re-adding the same toast id should upsert in place instead of stacking duplicates.
it('should upsert an existing toast when add is called with the same id', async () => {
render(<ToastHost />)
act(() => {
toast('Syncing', {
id: 'sync-job',
description: 'Uploading changes…',
})
})
expect(await screen.findByText('Syncing')).toBeInTheDocument()
act(() => {
toast.success('Synced', {
id: 'sync-job',
description: 'All changes are uploaded.',
})
})
expect(screen.queryByText('Syncing')).not.toBeInTheDocument()
expect(screen.getByText('Synced')).toBeInTheDocument()
expect(screen.getByText('All changes are uploaded.')).toBeInTheDocument()
expect(screen.getAllByRole('dialog')).toHaveLength(1)
})
// Action props should pass through to the Base UI action button.
it('should render and invoke toast action props', async () => {
const onAction = vi.fn()

View File

@ -501,6 +501,34 @@ describe('http path', () => {
expect(onChange).toHaveBeenCalled()
})
it('should only append a new key-value row after the last value field receives content', () => {
const onChange = vi.fn()
const onRemove = vi.fn()
const onAdd = vi.fn()
render(
<KeyValueItem
instanceId="kv-append"
nodeId="node-1"
readonly={false}
canRemove
payload={{ id: 'kv-append', key: 'name', value: '', type: 'text' } as any}
onChange={onChange}
onRemove={onRemove}
isLastItem
onAdd={onAdd}
/>,
)
const valueInput = screen.getAllByPlaceholderText('workflow.nodes.http.insertVarPlaceholder')[1]!
fireEvent.click(valueInput)
expect(onAdd).not.toHaveBeenCalled()
fireEvent.change(valueInput, { target: { value: 'alice' } })
expect(onChange).toHaveBeenCalledWith(expect.objectContaining({ value: 'alice' }))
expect(onAdd).toHaveBeenCalledTimes(1)
})
it('should edit key-only rows and select file payload rows', async () => {
const user = userEvent.setup()
const onChange = vi.fn()

View File

@ -47,20 +47,37 @@ const KeyValueItem: FC<Props> = ({
insertVarTipToLeft,
}) => {
const { t } = useTranslation()
const hasValuePayload = payload.type === 'file'
? !!payload.file?.length
: !!payload.value
const handleChange = useCallback((key: string) => {
return (value: string | ValueSelector) => {
const shouldAddNextItem = isLastItem
&& (
(key === 'value' && !payload.value && !!value)
|| (key === 'file' && (!payload.file || payload.file.length === 0) && Array.isArray(value) && value.length > 0)
)
const newPayload = produce(payload, (draft: any) => {
draft[key] = value
})
onChange(newPayload)
if (shouldAddNextItem)
onAdd()
}
}, [onChange, payload])
}, [isLastItem, onAdd, onChange, payload])
const filterOnlyFileVariable = (varPayload: Var) => {
return [VarType.file, VarType.arrayFile].includes(varPayload.type)
}
const handleValueContainerClick = useCallback(() => {
if (isLastItem && hasValuePayload)
onAdd()
}, [hasValuePayload, isLastItem, onAdd])
return (
// group class name is for hover row show remove button
<div className={cn(className, 'h-min-7 group flex border-t border-divider-regular')}>
@ -102,7 +119,10 @@ const KeyValueItem: FC<Props> = ({
/>
</div>
)}
<div className={cn(isSupportFile ? 'grow' : 'w-1/2')} onClick={() => isLastItem && onAdd()}>
<div
className={cn(isSupportFile ? 'grow' : 'w-1/2')}
onClick={handleValueContainerClick}
>
{(isSupportFile && payload.type === 'file')
? (
<VarReferencePicker

View File

@ -56,6 +56,7 @@
"@amplitude/analytics-browser": "catalog:",
"@amplitude/plugin-session-replay-browser": "catalog:",
"@base-ui/react": "catalog:",
"@date-fns/tz": "catalog:",
"@emoji-mart/data": "catalog:",
"@floating-ui/react": "catalog:",
"@formatjs/intl-localematcher": "catalog:",
@ -90,6 +91,7 @@
"cmdk": "catalog:",
"copy-to-clipboard": "catalog:",
"cron-parser": "catalog:",
"date-fns": "catalog:",
"dayjs": "catalog:",
"decimal.js": "catalog:",
"dompurify": "catalog:",

View File

@ -1,11 +1,21 @@
import type { Schema } from 'jsonschema'
import type { Schema, ValidationError, ValidatorResult } from 'jsonschema'
import { Validator } from 'jsonschema'
import draft07Schema from './draft-07.json'
const validator = new Validator()
export const draft07Validator = (schema: any) => {
return validator.validate(schema, draft07Schema as unknown as Schema)
type Draft07ValidationResult = Pick<ValidatorResult, 'valid' | 'errors'>
export const draft07Validator = (schema: any): Draft07ValidationResult => {
try {
return validator.validate(schema, draft07Schema as unknown as Schema)
}
catch {
// The jsonschema library may throw URL errors in browser environments
// when resolving schema $id URIs. Return empty errors since structural
// validation is handled separately by preValidateSchema (#34841).
return { valid: true, errors: [] as ValidationError[] }
}
}
export const forbidBooleanProperties = (schema: any, path: string[] = []): string[] => {