Mirror of https://github.com/langgenius/dify.git, synced 2026-05-12 07:37:09 +08:00

Commit 8747e3a2d3: Merge branch 'main' into jzh

.github/workflows/docker-build.yml (vendored)
@@ -6,14 +6,7 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- packages/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
- .nvmrc

concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService


class AdvancedPromptTemplateQuery(BaseModel):
@@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

return AdvancedPromptTemplateService.get_prompt(args.model_dump())
prompt_args: AdvancedPromptTemplateArgs = {
"app_mode": args.app_mode,
"model_mode": args.model_mode,
"model_name": args.model_name,
"has_context": args.has_context,
}
return AdvancedPromptTemplateService.get_prompt(prompt_args)

@@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):

# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
@@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,

@@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService


def _build_backstage_input_url(form_token: str | None) -> str | None:
@@ -214,7 +214,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
Get advanced chat app workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status

# Default to DEBUGGING if not specified
triggered_from = (
@@ -356,7 +360,11 @@ class WorkflowRunListApi(Resource):
Get workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status

# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (

@@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):

node_id = args.node_id

with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
@@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None

with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get all triggers for this app using select API
triggers = (
session.execute(

@@ -1,7 +1,10 @@
import logging

import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Unauthorized

import services
from configs import dify_config
@@ -42,12 +45,13 @@ from libs.token import (
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
logger = logging.getLogger(__name__)


class LoginPayload(LoginPayloadBase):
@@ -91,10 +95,12 @@ class LoginApi(Resource):
normalized_email = request_email.lower()

if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()

is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
raise EmailPasswordLoginLimitError()

invite_token = args.invite_token
@@ -110,14 +116,20 @@ class LoginApi(Resource):
invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
_log_console_login_failure(
email=normalized_email,
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
)
raise InvalidEmailError()
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):

token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()

token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()

if token_data["code"] != args.code:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()

AccountService.revoke_email_code_login_token(args.token)
try:
account = _get_account_with_case_fallback(original_email)
except Unauthorized as exc:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
@@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)


def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Console login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import TypedDict

from flask import request
@@ -13,6 +14,14 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"


class NotificationLangContent(TypedDict, total=False):
lang: str
title: str
subtitle: str
body: str
titlePicUrl: str


class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
@@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
notifications: list[NotificationItemDict]


def _pick_lang_content(contents: dict, lang: str) -> dict:
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
return (
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
)


class DismissNotificationPayload(BaseModel):
@@ -71,7 +82,7 @@ class NotificationApi(Resource):

notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),

@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Literal
from typing import Any, Literal

import pytz
from flask import request
@@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)


def _serialize_account(account) -> dict:
def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")


@@ -20,7 +20,7 @@ from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
from services.operation_service import OperationService, UtmInfo

from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout

@@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
utm_info = request.cookies.get("utm_info")

if utm_info:
utm_info_dict: dict = json.loads(utm_info)
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)

return view(*args, **kwargs)

@@ -2,7 +2,7 @@ from typing import Any, Union

from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@@ -158,14 +158,20 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")

def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]

def _create_variable_entity(self, item: dict) -> VariableEntity:
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
try:
variable_type = VariableEntityType(variable_type_raw)
except ValueError as e:
raise MCPRequestError(
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
) from e
variable = item[variable_type_raw]

return VariableEntity(
type=variable_type,
@@ -178,7 +184,7 @@ class MCPAppApi(Resource):
json_schema=variable.get("json_schema"),
)

def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)

@@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService


def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict


def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}

result = []
result: list[dict[str, Any]] = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result

@@ -5,6 +5,7 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict

from flask import Response, request
from flask_restx import Resource
@@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())


class FormDefinitionPayload(TypedDict):
form_content: Any
inputs: Any
resolved_default_values: dict[str, str]
user_actions: Any
expiration_time: int
site: NotRequired[dict]


def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump()
payload = {
payload: FormDefinitionPayload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),

@@ -1,7 +1,10 @@
import logging

from flask import make_response, request
from flask_restx import Resource
from jwt import InvalidTokenError
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized

import services
from configs import dify_config
@@ -20,7 +23,7 @@ from controllers.console.wraps import (
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import EmailStr
from libs.helper import EmailStr, extract_remote_ip
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@@ -29,9 +32,11 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService

logger = logging.getLogger(__name__)


class LoginPayload(LoginPayloadBase):
@field_validator("password")
@@ -76,14 +81,18 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
payload = LoginPayload.model_validate(web_ns.payload or {})
normalized_email = payload.email.lower()

try:
account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError()
except services.errors.account.AccountNotFoundError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()

token = WebAppAuthService.login(account=account)
@@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):

token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()

token_email = token_data.get("email")
if not isinstance(token_email, str):
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()

if token_data["code"] != payload.code:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()

WebAppAuthService.revoke_email_code_login_token(payload.token)
account = WebAppAuthService.get_user_through_email(token_email)
try:
account = WebAppAuthService.get_user_through_email(token_email)
except Unauthorized as exc:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
if not account:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()

token = WebAppAuthService.login(account=account)
@@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response


def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Web login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

@@ -1,5 +1,6 @@
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any

from flask import make_response, request
from flask_restx import Resource
@@ -103,21 +104,23 @@ class PassportResource(Resource):
return response


def decode_enterprise_webapp_user_id(jwt_token: str | None):
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
"""
Decode the enterprise user session from the Authorization header.
"""
if not jwt_token:
return None

decoded = PassportService().verify(jwt_token)
decoded: dict[str, Any] = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded


def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
def exchange_token_for_existing_web_user(
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
"""
Exchange a token for an existing web user session.
"""

@@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast

from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
@@ -113,12 +113,12 @@ class AppSiteInfo:
}


def serialize_site(site: Site) -> dict:
def serialize_site(site: Site) -> dict[str, Any]:
"""Serialize Site model using the same schema as AppSiteApi."""
return cast(dict, marshal(site, AppSiteApi.site_fields))
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))


def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))

@@ -138,7 +138,9 @@ class DatasetConfigManager:
)

@classmethod
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for dataset feature

@@ -172,7 +174,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]

@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
"""
Extract dataset config for legacy compatibility


@@ -108,7 +108,7 @@ class ModelConfigManager:
return dict(config), ["model"]

@classmethod
def validate_model_completion_params(cls, cp: dict):
def validate_model_completion_params(cls, cp: dict[str, Any]):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")

@@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
)

@classmethod
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate pre_prompt and set defaults for prompt feature
depending on the config['model']
@@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]

@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict):
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
"""
Validate post_prompt and set defaults for prompt feature


@@ -1,5 +1,5 @@
import re
from typing import cast
from typing import Any, cast

from graphon.variables.input_entities import VariableEntity, VariableEntityType

@@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
return variable_entities, external_data_variables

@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form

@@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
return config, related_config_keys

@classmethod
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form

@@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
return config, ["user_input_form"]

@classmethod
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_external_data_tools_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for external data fetch feature


@@ -30,7 +30,7 @@ class FileUploadConfigManager:
return FileUploadConfig.model_validate(file_upload_dict)

@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for file upload feature

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, ValidationError


@@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):

class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config

@@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
return AppConfigModel.model_validate(validated_config).more_like_this.enabled

@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError:

@@ -1,6 +1,9 @@
from typing import Any


class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
"""
Convert model config to model config

@@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
return opening_statement, suggested_questions_list

@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for opening statement feature


@@ -1,6 +1,9 @@
from typing import Any


class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict:
@@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
return show_retrieve_source

@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for retriever resource feature


@@ -1,6 +1,9 @@
from typing import Any


class SpeechToTextConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config

@@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
return speech_to_text

@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for speech to text feature


@@ -1,6 +1,9 @@
from typing import Any


class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config

@@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
return suggested_questions_after_answer

@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for suggested questions feature


@@ -1,9 +1,11 @@
from typing import Any

from core.app.app_config.entities import TextToSpeechEntity


class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict):
def convert(cls, config: dict[str, Any]):
"""
Convert model config to model config

@@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
return text_to_speech

@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for text to speech feature

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse

@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())

@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

@@ -1,3 +1,5 @@
from typing import Any

from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
@@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
return pipeline_config

@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> dict[str, Any]:
"""
Validate for pipeline config


@@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
next_page_parameters: dict[str, Any] | None = None,
):
"""
Get files in a folder.

@@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)

event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
@@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus

@@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None

@field_validator("parameters", mode="before")
@classmethod
@@ -192,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel):

time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
tool_config: dict[str, Any] | None = None

@classmethod
def empty(cls) -> DatasourceInvokeMeta:
@@ -242,7 +242,7 @@ class OnlineDocumentPage(BaseModel):

page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: dict | None = Field(None, description="The page icon")
page_icon: dict[str, Any] | None = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: str | None = Field(None, description="The parent page id")
@@ -301,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel):
Get website crawl request
"""

crawl_parameters: dict = Field(..., description="The crawl parameters")
crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters")


class WebSiteInfoDetail(BaseModel):
@@ -358,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel):
bucket: str | None = Field(None, description="The file bucket")
files: list[OnlineDriveFile] = Field(..., description="The file list")
is_truncated: bool = Field(False, description="Whether the result is truncated")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")


class OnlineDriveBrowseFilesRequest(BaseModel):
@@ -369,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
bucket: str | None = Field(None, description="The file bucket")
prefix: str = Field(..., description="The parent folder ID")
max_keys: int = Field(20, description="Page size for pagination")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")


class OnlineDriveBrowseFilesResponse(BaseModel):

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, Field, field_validator


@@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: dict | None = None
data_source_info: dict[str, Any] | None = None
name: str
indexing_status: str
error: str | None = None

@@ -6,6 +6,7 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any

from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)

def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
"""
Get current credentials.

@@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):

return session.execute(stmt).scalar_one_or_none()

def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None

def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
"""
Get provider credentials.

@@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
else [],
)

def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
provider_names.append(model_provider_id.provider_name)
return provider_names

def create_provider_credential(self, credentials: dict, credential_name: str | None):
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):

def update_provider_credential(
self,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
):
@@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):

def _get_specific_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str
) -> dict | None:
) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None

def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> dict[str, Any] | None:
"""
Get custom model credentials.

@@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
self,
model_type: ModelType,
model: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
@@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)

def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
) -> None:
"""
Create a custom model credential.
@@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
raise

def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
self,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
credential_id: str,
) -> None:
"""
Update a custom model credential.
@@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)

def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
"""
Get model schema
"""
@@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):

return secret_input_form_variables

def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.

@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import StrEnum, auto
from typing import Union
from typing import Any, Union

from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, ConfigDict, Field
@@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
enabled: bool
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
credentials: dict | None = None
credentials: dict[str, Any] | None = None


class CustomProviderConfiguration(BaseModel):
@@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
Model class for provider custom configuration.
"""

credentials: dict
credentials: dict[str, Any]
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] = []
@@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):

model: str
model_type: ModelType
credentials: dict | None
credentials: dict[str, Any] | None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = []

@@ -1,6 +1,7 @@
import json
from enum import StrEnum
from json import JSONDecodeError
from typing import Any

from extensions.ext_redis import redis_client

@@ -15,7 +16,7 @@ class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""
Get cached model provider credentials.

@@ -33,7 +34,7 @@ class ProviderCredentialsCache:
else:
return None

def set(self, credentials: dict):
def set(self, credentials: dict[str, Any]):
"""
Cache model provider credentials.


@@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC):
"""Generate cache key based on subclass implementation"""
pass

def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
@@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""

def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""Get cached provider credentials"""
return None


@@ -1,6 +1,7 @@
import json
from enum import StrEnum
from json import JSONDecodeError
from typing import Any

from extensions.ext_redis import redis_client

@@ -18,7 +19,7 @@ class ToolParameterCache:
f":identity_id:{identity_id}"
)

def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""
Get cached model provider credentials.

@@ -36,7 +37,7 @@ class ToolParameterCache:
else:
return None

def set(self, parameters: dict):
def set(self, parameters: dict[str, Any]):
"""Cache model provider credentials."""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

@@ -115,7 +115,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
@@ -126,7 +126,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
@@ -137,7 +137,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@@ -147,7 +147,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
@@ -528,7 +528,7 @@ class LBModelManager:
model_type: ModelType,
model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: dict | None = None,
managed_credentials: dict[str, Any] | None = None,
):
"""
Load balancing model manager

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, Field
from sqlalchemy import select

@@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension

class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""


@@ -23,7 +25,7 @@ class ApiModeration(Moderation):
name: str = "api"

@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.

@@ -41,7 +43,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")

def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@@ -73,7 +75,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)

def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
if self.config is None:
raise ValueError("The config is not set.")
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from typing import Any

from pydantic import BaseModel, Field

@@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""


@@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):

module: ExtensionModule = ExtensionModule.MODERATION

def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id

@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
"""
Validate the incoming form config data.

@@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError

@abstractmethod
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
@@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError

@classmethod
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):

@ -1,3 +1,5 @@
from typing import Any

from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation

def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)

@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.

@ -24,7 +26,7 @@ class ModerationFactory:
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore

def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review

@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"

@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.

@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")

def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)

def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:

@ -1,3 +1,5 @@
from typing import Any

from graphon.model_runtime.entities.model_entities import ModelType

from core.model_manager import ModelManager
@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"

@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.

@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)

def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)

def _is_violated(self, inputs: dict):
def _is_violated(self, inputs: dict[str, Any]):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(

@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")

def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
attributes: dict[str, str] = {}

@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
set_attribute(path, value)

def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
def set_tool_call_attributes(
message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
) -> None:
"""Extract and assign tool call details safely."""
if not tool_call:
return

@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance):

return inputs, attributes

def _parse_knowledge_retrieval_outputs(self, outputs: dict):
def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
"""Parse KR outputs and attributes from KR workflow node"""
retrieved = outputs.get("result", [])

@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance):
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)

def _get_message_user_id(self, metadata: dict) -> str | None:
def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
if (end_user_id := metadata.get("from_end_user_id")) and (
end_user_data := db.session.get(EndUser, end_user_id)
):
@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance):
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]

def _set_trace_metadata(self, span: Span, metadata: dict):
def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
token = None
try:
# NB: Set span in context such that we can use update_current_trace() API
@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance):
return messages
return prompts # Fallback to original format

def _parse_single_message(self, item: dict):
def _parse_single_message(self, item: dict[str, Any]):
"""Postprocess single message format to be standard chat message"""
role = item.get("role", "user")
msg = {"role": role, "content": item.get("text", "")}

@ -3,7 +3,7 @@ import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import cast
from typing import Any, cast

from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from opik import Opik, Trace
@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance):

self.add_span(span_data)

def add_trace(self, opik_trace_data: dict) -> Trace:
def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
try:
trace = self.opik_client.trace(**opik_trace_data)
logger.debug("Opik Trace created successfully")
@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance):
except Exception as e:
raise ValueError(f"Opik Failed to create trace: {str(e)}")

def add_span(self, opik_span_data: dict):
def add_span(self, opik_span_data: dict[str, Any]):
try:
self.opik_client.span(**opik_span_data)
logger.debug("Opik Span created successfully")

@ -324,7 +324,7 @@ class OpsTraceManager:

@classmethod
def encrypt_tracing_config(
cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any], current_trace_config=None
):
"""
Encrypt tracing config.
@ -363,7 +363,7 @@ class OpsTraceManager:
return encrypted_config.model_dump()

@classmethod
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
"""
Decrypt tracing config
:param tenant_id: tenant id
@ -408,7 +408,7 @@ class OpsTraceManager:
return dict(decrypted_config)

@classmethod
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict[str, Any]):
"""
Decrypt tracing config
:param tracing_provider: tracing provider
@ -581,7 +581,7 @@ class OpsTraceManager:
return app_trace_config

@staticmethod
def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
def check_trace_config_is_effective(tracing_config: dict[str, Any], tracing_provider: str):
"""
Check trace config is effective
:param tracing_config: tracing config
@ -596,7 +596,7 @@ class OpsTraceManager:
return trace_instance(config).api_check()

@staticmethod
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
def get_trace_config_project_key(tracing_config: dict[str, Any], tracing_provider: str):
"""
get trace config is project key
:param tracing_config: tracing config
@ -611,7 +611,7 @@ class OpsTraceManager:
return trace_instance(config).get_project_key()

@staticmethod
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
def get_trace_config_project_url(tracing_config: dict[str, Any], tracing_provider: str):
"""
get trace config is project key
:param tracing_config: tracing config
@ -1322,8 +1322,8 @@ class TraceTask:
error=error,
)

def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
node_data: dict = kwargs.get("node_execution_data", {})
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict[str, Any]:
node_data: dict[str, Any] = kwargs.get("node_execution_data", {})
if not node_data:
return {}

@ -1431,7 +1431,7 @@ class TraceTask:
return node_trace
return DraftNodeExecutionTrace(**node_trace.model_dump())

def _extract_streaming_metrics(self, message_data) -> dict:
def _extract_streaming_metrics(self, message_data) -> dict[str, Any]:
if not message_data.message_metadata:
return {}

@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field, model_validator

@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity):
entity of an endpoint
"""

settings: dict
settings: dict[str, Any]
tenant_id: str
plugin_id: str
expired_at: datetime

@ -1,3 +1,5 @@
from typing import Any

from graphon.model_runtime.entities.provider_entities import ProviderEntity
from pydantic import BaseModel, Field, computed_field, model_validator

@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel):

@model_validator(mode="before")
@classmethod
def transform_declaration(cls, data: dict):
def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]:
if "endpoint" in data and not data["endpoint"]:
del data["endpoint"]
if "model" in data and not data["model"]:

@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel):

@model_validator(mode="before")
@classmethod
def validate_category(cls, values: dict):
def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]:
# auto detect category
if values.get("tool"):
values["category"] = PluginCategory.Tool

@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel):
"""

result: bool
credentials: dict | None = None
credentials: dict[str, Any] | None = None

class PluginModelSchemaEntity(BaseModel):

@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel):
tool_type: Literal["builtin", "workflow", "api", "mcp"]
provider: str
tool: str
tool_parameters: dict
tool_parameters: dict[str, Any]
credential_id: str | None = None

@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel):
opt: Literal["encrypt", "decrypt", "clear"]
namespace: Literal["endpoint"]
identity: str
data: dict = Field(default_factory=dict)
data: dict[str, Any] = Field(default_factory=dict)
config: list[BasicProviderConfig] = Field(default_factory=list)

@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""

def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""

def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient):

tool_provider_id = DatasourceProviderID(provider_id)

def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
data = json_response.get("data")
if data:
for datasource in data.get("declaration", {}).get("datasources", []):

@ -1,3 +1,5 @@
from typing import Any

from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError

class PluginEndpointClient(BasePluginClient):
def create_endpoint(
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
self,
tenant_id: str,
user_id: str,
plugin_unique_identifier: str,
name: str,
settings: dict[str, Any],
) -> bool:
"""
Create an endpoint for the given plugin.
@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient):
params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
)

def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
def update_endpoint(
self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]
) -> bool:
"""
Update the settings of the given endpoint.
"""

@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> AIModelEntity | None:
"""
Get model schema
@ -80,7 +80,7 @@ class PluginModelClient(BasePluginClient):
return None

def validate_provider_credentials(
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the provider
@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> bool:
"""
validate the credentials of the provider
@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
input_type: str,
) -> EmbeddingResult:
@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
"""
@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None = None,
@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Generator[bytes, None, None]:
@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
language: str | None = None,
):
"""
@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
"""
@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
text: str,
) -> bool:
"""

@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any

from requests import HTTPError

@ -263,7 +264,7 @@ class PluginInstaller(BasePluginClient):
original_plugin_unique_identifier: str,
new_plugin_unique_identifier: str,
source: PluginInstallationSource,
meta: dict,
meta: dict[str, Any],
) -> PluginInstallTaskStartResponse:
"""
Upgrade a plugin.

@ -1,6 +1,7 @@
"""Abstract interface for document loader implementations."""

import csv
from typing import Any

import pandas as pd

@ -23,7 +24,7 @@ class CSVExtractor(BaseExtractor):
encoding: str | None = None,
autodetect_encoding: bool = False,
source_column: str | None = None,
csv_args: dict | None = None,
csv_args: dict[str, Any] | None = None,
):
"""Initialize with file path."""
self._file_path = file_path

@ -54,8 +54,8 @@ class BaseAPIClient:
self,
method: str,
endpoint: str,
query_params: dict | None = None,
data: dict | None = None,
query_params: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
**kwargs,
) -> Response:
stream = kwargs.pop("stream", False)
@ -66,19 +66,25 @@ class BaseAPIClient:

return self.session.request(method, url, params=query_params, json=data, **kwargs)

def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
def _get(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
return self._request("GET", endpoint, query_params=query_params, **kwargs)

def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _post(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs)

def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _put(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs)

def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
def _delete(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
return self._request("DELETE", endpoint, query_params=query_params, **kwargs)

def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _patch(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs)

@ -99,7 +105,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
finally:
response.close()

def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
def process_response(self, response: Response) -> dict[str, Any] | bytes | list[Any] | None | Generator:
if response.status_code == 401:
raise WaterCrawlAuthenticationError(response)

@ -186,7 +192,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
yield from generator

def get_crawl_request_results(
self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None
self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict[str, Any] | None = None
):
query_params = query_params or {}
query_params.update({"page": page or 1, "page_size": page_size or 25})
@ -210,7 +216,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
if event_data["type"] == "result":
return event_data["data"]

def download_result(self, result_object: dict):
def download_result(self, result_object: dict[str, Any]):
response = httpx.get(result_object["result"], timeout=None)
try:
response.raise_for_status()

@ -120,7 +120,7 @@ class WaterCrawlProvider:
}

def _get_results(
self, crawl_request_id: str, query_params: dict | None = None
self, crawl_request_id: str, query_params: dict[str, Any] | None = None
) -> Generator[WatercrawlDocumentData, None, None]:
page = 0
page_size = 100

@ -875,7 +875,11 @@ class DatasetRetrieval:
return retrieval_resource_list

def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
self,
flask_app: Flask,
documents: list[Document],
message_id: str | None = None,
timer: dict[str, Any] | None = None,
):
"""Handle retrieval end."""
with flask_app.app_context():
@ -980,7 +984,7 @@ class DatasetRetrieval:

self._send_trace_task(message_id, documents, timer)

def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None):
"""Send trace task if trace manager is available."""
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
@ -1142,7 +1146,7 @@ class DatasetRetrieval:
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
inputs: dict[str, Any],
) -> list[DatasetRetrieverBaseTool] | None:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
@ -1337,7 +1341,7 @@ class DatasetRetrieval:
metadata_filtering_mode: str,
metadata_model_config: ModelConfig,
metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict,
inputs: dict[str, Any],
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
@ -1417,7 +1421,7 @@ class DatasetRetrieval:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition

def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str:
if not inputs:
return text

@ -4,6 +4,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Any
from uuid import UUID

import numpy as np
@ -50,7 +51,7 @@ def safe_json_value(v):
return v

def safe_json_dict(d: dict):
def safe_json_dict(d: dict[str, Any]):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
@ -196,11 +197,11 @@ class ToolFileMessageTransformer:

@staticmethod
def _with_tool_file_meta(
meta: dict | None,
meta: dict[str, Any] | None,
*,
tool_file_id: str | None = None,
url: str | None = None,
) -> dict:
) -> dict[str, Any]:
normalized_meta = meta.copy() if meta is not None else {}
resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url)
if resolved_tool_file_id and "tool_file_id" not in normalized_meta:

@ -32,7 +32,7 @@ class OpenAPISpecDict(TypedDict):
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
openapi: Mapping[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -236,7 +236,7 @@ class ApiBasedToolSchemaParser:
return value

@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
def _get_tool_parameter_type(parameter: dict[str, Any]) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: str | None = None
if parameter.get("format") == "binary":
@ -265,7 +265,7 @@ class ApiBasedToolSchemaParser:

@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
yaml: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
@ -278,14 +278,14 @@ class ApiBasedToolSchemaParser:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}

openapi: dict = safe_load(yaml)
openapi: dict[str, Any] = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

@staticmethod
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
swagger: dict[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> OpenAPISpecDict:
warning = warning or {}
"""
@ -351,7 +351,7 @@ class ApiBasedToolSchemaParser:

@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
json: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
@ -392,7 +392,7 @@ class ApiBasedToolSchemaParser:

@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
content: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle

@ -1,4 +1,5 @@
import copy
from typing import Any, TypedDict

from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
@ -15,9 +16,18 @@ from core.prompt.prompt_templates.advanced_prompt_templates import (
from models.model import AppMode

class AdvancedPromptTemplateArgs(TypedDict):
"""Expected shape of the args dict passed to AdvancedPromptTemplateService.get_prompt."""

app_mode: str
model_mode: str
model_name: str
has_context: str

class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict):
def get_prompt(cls, args: AdvancedPromptTemplateArgs) -> dict[str, Any]:
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
@ -29,7 +39,7 @@ class AdvancedPromptTemplateService:
return cls.get_common_prompt(app_mode, model_mode, has_context)

@classmethod
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]:
context_prompt = copy.deepcopy(CONTEXT)

match app_mode:
@ -63,7 +73,7 @@ class AdvancedPromptTemplateService:
return {}

@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
def get_completion_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]:
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
@ -72,7 +82,7 @@ class AdvancedPromptTemplateService:
return prompt_template

@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
def get_chat_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]:
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
@ -81,7 +91,7 @@ class AdvancedPromptTemplateService:
return prompt_template

@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

match app_mode:

@ -233,7 +233,7 @@ class DatasetService:
embedding_model_provider: str | None = None,
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
summary_index_setting: dict | None = None,
summary_index_setting: dict[str, Any] | None = None,
):
# check if dataset name already exists
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
@ -2493,7 +2493,7 @@ class DocumentService:
data_source_type: str,
document_form: str,
document_language: str,
data_source_info: dict,
data_source_info: dict[str, Any],
created_from: str,
position: int,
account: Account,
@ -2850,7 +2850,7 @@ class DocumentService:
raise ValueError("Process rule segmentation max_tokens is invalid")

@classmethod
def estimate_args_validate(cls, args: dict):
def estimate_args_validate(cls, args: dict[str, Any]):
if "info_list" not in args or not args["info_list"]:
raise ValueError("Data source info is required")

@ -3132,7 +3132,7 @@ class DocumentService:

class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
def segment_create_args_validate(cls, args: dict[str, Any], document: Document):
if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
@ -3149,7 +3149,7 @@ class SegmentService:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")

@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
def create_segment(cls, args: dict[str, Any], document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None

@ -1,9 +1,25 @@
from enum import StrEnum, auto

from pydantic import BaseModel, Field, field_validator

from libs.helper import EmailStr
from libs.password import valid_password

class LoginFailureReason(StrEnum):
"""Bounded reason codes for failed login audit logs."""

ACCOUNT_BANNED = auto()
ACCOUNT_IN_FREEZE = auto()
ACCOUNT_NOT_FOUND = auto()
EMAIL_CODE_EMAIL_MISMATCH = auto()
INVALID_CREDENTIALS = auto()
INVALID_EMAIL_CODE = auto()
INVALID_EMAIL_CODE_TOKEN = auto()
INVALID_INVITATION_EMAIL = auto()
LOGIN_RATE_LIMITED = auto()

class LoginPayloadBase(BaseModel):
email: EmailStr
password: str

@ -47,7 +47,7 @@ class ExternalDatasetService:
return external_knowledge_apis.items, external_knowledge_apis.total

@classmethod
def validate_api_list(cls, api_settings: dict):
def validate_api_list(cls, api_settings: dict[str, Any]):
if not api_settings:
raise ValueError("api list is empty")
if not api_settings.get("endpoint"):
@ -56,7 +56,7 @@ class ExternalDatasetService:
raise ValueError("api_key is required")

@staticmethod
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict[str, Any]) -> ExternalKnowledgeApis:
settings = args.get("settings")
if settings is None:
raise ValueError("settings is required")
@ -75,7 +75,7 @@ class ExternalDatasetService:
return external_knowledge_api

@staticmethod
def check_endpoint_and_api_key(settings: dict):
def check_endpoint_and_api_key(settings: dict[str, Any]):
if "endpoint" not in settings or not settings["endpoint"]:
raise ValueError("endpoint is required")
if "api_key" not in settings or not settings["api_key"]:
@ -178,7 +178,9 @@ class ExternalDatasetService:
return external_knowledge_binding

@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
def document_create_args_validate(
tenant_id: str, external_knowledge_api_id: str, process_parameter: dict[str, Any]
):
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
@ -222,7 +224,7 @@ class ExternalDatasetService:
return response

@staticmethod
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
def assembling_headers(authorization: Authorization, headers: dict[str, Any] | None = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
@ -248,11 +250,11 @@ class ExternalDatasetService:
return headers

@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
def get_external_knowledge_api_settings(settings: dict[str, Any]) -> ExternalKnowledgeApiSetting:
return ExternalKnowledgeApiSetting.model_validate(settings)

@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
def create_external_dataset(tenant_id: str, user_id: str, args: dict[str, Any]) -> Dataset:
# check if dataset name already exists
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
@ -304,7 +306,7 @@ class ExternalDatasetService:
tenant_id: str,
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
external_retrieval_parameters: dict[str, Any],
metadata_condition: MetadataFilteringCondition | None = None,
):
external_knowledge_binding = db.session.scalar(

@ -44,7 +44,7 @@ class HitTestingService:
dataset: Dataset,
query: str,
account: Account,
retrieval_model: dict | None,
retrieval_model: dict[str, Any] | None,
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,

@ -1,8 +1,22 @@
import os
from typing import TypedDict

import httpx

class UtmInfo(TypedDict, total=False):
"""Expected shape of the utm_info dict passed to record_utm.

All fields are optional; missing keys default to an empty string.
"""

utm_source: str
utm_medium: str
utm_campaign: str
utm_content: str
utm_term: str

class OperationService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@ -17,7 +31,7 @@ class OperationService:
return response.json()

@classmethod
def record_utm(cls, tenant_id: str, utm_info: dict):
def record_utm(cls, tenant_id: str, utm_info: UtmInfo):
params = {
"tenant_id": tenant_id,
"utm_source": utm_info.get("utm_source", ""),

@ -1,3 +1,5 @@
from typing import Any

from sqlalchemy import select

from core.ops.entities.config_entity import BaseTracingConfig
@ -135,7 +137,7 @@ class OpsService:
return trace_config_data.to_dict()

@classmethod
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
"""
Create tracing app config
:param app_id: app id
@ -210,7 +212,7 @@ class OpsService:
return {"result": "success"}

@classmethod
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
"""
Update tracing app config
:param app_id: app id

@ -1,9 +1,13 @@
from typing import Any

from core.plugin.impl.endpoint import PluginEndpointClient

class EndpointService:
@classmethod
def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict):
def create_endpoint(
cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict[str, Any]
):
return PluginEndpointClient().create_endpoint(
tenant_id=tenant_id,
user_id=user_id,
@ -32,7 +36,7 @@ class EndpointService:
)

@classmethod
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]):
return PluginEndpointClient().update_endpoint(
tenant_id=tenant_id,
user_id=user_id,

@ -1,6 +1,7 @@
import json
from os import path
from pathlib import Path
from typing import Any

from flask import current_app

@ -13,21 +14,21 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json
"""

builtin_data: dict | None = None
builtin_data: dict[str, Any] | None = None

def get_type(self) -> str:
return PipelineTemplateType.BUILTIN

def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
result = self.fetch_pipeline_templates_from_builtin(language)
return result

def get_pipeline_template_detail(self, template_id: str):
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_builtin(template_id)
return result

@classmethod
def _get_builtin_data(cls) -> dict:
def _get_builtin_data(cls) -> dict[str, Any]:
"""
Get builtin data.
:return:
@ -43,7 +44,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return cls.builtin_data or {}

@classmethod
def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict:
def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict[str, Any]:
"""
Fetch pipeline templates from builtin.
:param language: language
@ -53,7 +54,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return builtin_data.get("pipeline_templates", {}).get(language, {})

@classmethod
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None:
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict[str, Any] | None:
"""
Fetch pipeline template detail from builtin.
:param template_id: Template ID

@ -1,3 +1,5 @@
from typing import Any

import yaml
from sqlalchemy import select

@ -13,12 +15,12 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval recommended app from database
"""

def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
_, current_tenant_id = current_account_with_tenant()
result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
return result

def get_pipeline_template_detail(self, template_id: str):
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result

@ -26,7 +28,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.CUSTOMIZED

@classmethod
def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict:
def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict[str, Any]:
"""
Fetch pipeline templates from db.
:param tenant_id: tenant id
@ -53,7 +55,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}

@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict[str, Any] | None:
"""
Fetch pipeline template detail from db.
:param template_id: Template ID

@ -1,3 +1,5 @@
from typing import Any

import yaml
from sqlalchemy import select

@ -12,11 +14,11 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval pipeline template from database
"""

def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
result = self.fetch_pipeline_templates_from_db(language)
return result

def get_pipeline_template_detail(self, template_id: str):
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result

@ -24,7 +26,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.DATABASE

@classmethod
def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
def fetch_pipeline_templates_from_db(cls, language: str) -> dict[str, Any]:
"""
Fetch pipeline templates from db.
:param language: language
@ -54,7 +56,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}

@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict[str, Any] | None:
"""
Fetch pipeline template detail from db.
:param pipeline_id: Pipeline ID

@ -1,15 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any

class PipelineTemplateRetrievalBase(ABC):
"""Interface for pipeline template retrieval."""

@abstractmethod
def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
raise NotImplementedError

@abstractmethod
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
raise NotImplementedError

@abstractmethod

@ -1,4 +1,5 @@
import logging
from typing import Any

import httpx

@ -15,8 +16,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval recommended app from dify official
"""

def get_pipeline_template_detail(self, template_id: str) -> dict | None:
result: dict | None
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result: dict[str, Any] | None
try:
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
@ -24,7 +25,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
return result

def get_pipeline_templates(self, language: str) -> dict:
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
try:
result = self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:
@ -36,7 +37,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.REMOTE

@classmethod
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict:
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict[str, Any]:
"""
Fetch pipeline template detail from dify official.

@ -53,11 +54,11 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
+ f" status_code: {response.status_code},"
+ f" response: {response.text[:1000]}"
)
data: dict = response.json()
data: dict[str, Any] = response.json()
return data

@classmethod
def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict:
def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict[str, Any]:
"""
Fetch pipeline templates from dify official.
:param language: language
@ -69,6 +70,6 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if response.status_code != 200:
raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}")

result: dict = response.json()
result: dict[str, Any] = response.json()

return result

@ -92,7 +92,7 @@ class ApiToolManageService:

@staticmethod
def convert_schema_to_tool_bundles(
schema: str, extra_info: dict | None = None
schema: str, extra_info: dict[str, Any] | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
convert schema to tool bundles
@ -109,8 +109,8 @@ class ApiToolManageService:
user_id: str,
tenant_id: str,
provider_name: str,
icon: dict,
credentials: dict,
icon: dict[str, Any],
credentials: dict[str, Any],
schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
@ -244,8 +244,8 @@ class ApiToolManageService:
tenant_id: str,
provider_name: str,
original_provider: str,
icon: dict,
credentials: dict,
icon: dict[str, Any],
credentials: dict[str, Any],
_schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str | None,
@ -356,8 +356,8 @@ class ApiToolManageService:
tenant_id: str,
provider_name: str,
tool_name: str,
credentials: dict,
parameters: dict,
credentials: dict[str, Any],
parameters: dict[str, Any],
schema_type: ApiProviderSchemaType,
schema: str,
):

@ -147,7 +147,7 @@ class BuiltinToolManageService:
tenant_id: str,
provider: str,
credential_id: str,
credentials: dict | None = None,
credentials: dict[str, Any] | None = None,
name: str | None = None,
):
"""
@ -177,7 +177,7 @@ class BuiltinToolManageService:
)

original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials: dict = {
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
@ -216,7 +216,7 @@ class BuiltinToolManageService:
api_type: CredentialType,
tenant_id: str,
provider: str,
credentials: dict,
credentials: dict[str, Any],
expires_at: int = -1,
name: str | None = None,
):
@ -657,7 +657,7 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: dict | None = None,
client_params: dict[str, Any] | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""

@ -69,7 +69,9 @@ class ToolTransformService:
return ""

@staticmethod
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
def repack_provider(
tenant_id: str, provider: dict[str, Any] | ToolProviderApiEntity | PluginDatasourceProviderEntity
):
"""
repack provider

@ -1,6 +1,7 @@
import json
import logging
from datetime import datetime
from typing import Any

from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import delete, or_, select
@ -35,7 +36,7 @@ class WorkflowToolManageService:
workflow_app_id: str,
name: str,
label: str,
icon: dict,
icon: dict[str, Any],
description: str,
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
@ -117,7 +118,7 @@ class WorkflowToolManageService:
workflow_tool_id: str,
name: str,
label: str,
icon: dict,
icon: dict[str, Any],
description: str,
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",

@ -91,7 +91,7 @@ class WebsiteCrawlApiRequest:
|
||||
return CrawlRequest(url=self.url, provider=self.provider, options=options)
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
|
||||
def from_args(cls, args: dict[str, Any]) -> WebsiteCrawlApiRequest:
|
||||
"""Create from Flask-RESTful parsed arguments."""
|
||||
provider = args.get("provider")
|
||||
url = args.get("url")
|
||||
@ -115,7 +115,7 @@ class WebsiteCrawlStatusApiRequest:
|
||||
job_id: str
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
|
||||
def from_args(cls, args: dict[str, Any], job_id: str) -> WebsiteCrawlStatusApiRequest:
|
||||
"""Create from Flask-RESTful parsed arguments."""
|
||||
provider = args.get("provider")
|
||||
if not provider:
|
||||
@ -163,7 +163,7 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
|
||||
def _get_decrypted_api_key(cls, tenant_id: str, config: dict[str, Any]) -> str:
|
||||
"""Decrypt and return the API key from config."""
|
||||
api_key = config.get("api_key")
|
||||
if not api_key:
|
||||
@ -171,7 +171,7 @@ class WebsiteService:
|
||||
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
def document_create_args_validate(cls, args: dict[str, Any]):
|
||||
"""Validate arguments for document creation."""
|
||||
try:
|
||||
WebsiteCrawlApiRequest.from_args(args)
|
||||
@ -195,7 +195,7 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
|
||||
params: dict[str, Any]
|
||||
@ -225,7 +225,7 @@ class WebsiteService:
|
||||
return {"status": "active", "job_id": job_id}
|
||||
|
||||
@classmethod
|
||||
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
# Convert CrawlOptions back to dict format for WaterCrawlProvider
|
||||
options = {
|
||||
"limit": request.options.limit,
|
||||
@ -290,7 +290,7 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> CrawlStatusDict:
|
||||
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data: CrawlStatusDict = {
|
||||
@ -364,7 +364,9 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
|
||||
def _get_firecrawl_url_data(
|
||||
cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
crawl_data: list[FirecrawlDocumentData] | None = None
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
@ -438,7 +440,7 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
params = {"onlyMainContent": request.only_main_content}
|
||||
return dict(firecrawl_app.scrape_url(url=request.url, params=params))
|
||||
|
||||
@ -1,5 +1,6 @@
import threading
from collections.abc import Sequence
from typing import TypedDict

from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
@ -19,6 +20,14 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory


class WorkflowRunListArgs(TypedDict, total=False):
    """Expected shape of the args dict passed to workflow run pagination methods."""

    limit: int
    last_id: str
    status: str


class WorkflowRunService:
    _session_factory: sessionmaker
    _workflow_run_repo: APIWorkflowRunRepository
@ -37,7 +46,10 @@ class WorkflowRunService:
        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)

    def get_paginate_advanced_chat_workflow_runs(
        self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
        self,
        app_model: App,
        args: WorkflowRunListArgs,
        triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
    ) -> InfiniteScrollPagination:
        """
        Get advanced chat app workflow run list
@ -73,7 +85,10 @@ class WorkflowRunService:
        return pagination

    def get_paginate_workflow_runs(
        self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
        self,
        app_model: App,
        args: WorkflowRunListArgs,
        triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
    ) -> InfiniteScrollPagination:
        """
        Get workflow run list

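With `total=False`, every key in `WorkflowRunListArgs` stays optional, so callers that only pass `limit` keep type-checking while typos in key names are flagged. A minimal sketch of the intent, assuming the module path `services.workflow_run_service` used elsewhere in this change (the values here are illustrative):

    from services.workflow_run_service import WorkflowRunListArgs

    args: WorkflowRunListArgs = {"limit": 20, "last_id": "prev-run-id"}  # "status" may be omitted because total=False
    # A type checker now rejects {"limt": 20} or {"limit": "20"}, while at runtime this is still a plain dict.
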
@ -11,6 +11,7 @@ from uuid import uuid4

import pytest
from faker import Faker
from sqlalchemy import delete

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
@ -28,12 +29,12 @@ class TestCreateSegmentToIndexTask:
        """Clean up database and Redis before each test to ensure isolation."""

        # Clear all test data using fixture session
        db_session_with_containers.query(DocumentSegment).delete()
        db_session_with_containers.query(Document).delete()
        db_session_with_containers.query(Dataset).delete()
        db_session_with_containers.query(TenantAccountJoin).delete()
        db_session_with_containers.query(Tenant).delete()
        db_session_with_containers.query(Account).delete()
        db_session_with_containers.execute(delete(DocumentSegment))
        db_session_with_containers.execute(delete(Document))
        db_session_with_containers.execute(delete(Dataset))
        db_session_with_containers.execute(delete(TenantAccountJoin))
        db_session_with_containers.execute(delete(Tenant))
        db_session_with_containers.execute(delete(Account))
        db_session_with_containers.commit()

        # Clear Redis cache

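These cleanup fixtures swap the legacy `Query.delete()` calls for SQLAlchemy 2.0-style `session.execute(delete(Model))`. A self-contained sketch of that pattern, using a throwaway SQLite model rather than the real Dify tables:

    from sqlalchemy import create_engine, delete
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Segment(Base):  # stand-in for DocumentSegment and friends
        __tablename__ = "segments"
        id: Mapped[int] = mapped_column(primary_key=True)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Segment(), Segment()])
        session.commit()
        session.execute(delete(Segment))  # bulk delete, replaces session.query(Segment).delete()
        session.commit()
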
@ -14,6 +14,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import delete
|
||||
|
||||
from libs.email_i18n import EmailType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -41,9 +42,9 @@ class TestSendEmailCodeLoginMailTask:
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
# Clear all test data
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.execute(delete(Account))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from sqlalchemy import delete
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
@ -30,14 +31,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(db_session_with_containers):
|
||||
db_session_with_containers.query(HumanInputFormRecipient).delete()
|
||||
db_session_with_containers.query(HumanInputDelivery).delete()
|
||||
db_session_with_containers.query(HumanInputForm).delete()
|
||||
db_session_with_containers.query(WorkflowPause).delete()
|
||||
db_session_with_containers.query(WorkflowRun).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.execute(delete(HumanInputFormRecipient))
|
||||
db_session_with_containers.execute(delete(HumanInputDelivery))
|
||||
db_session_with_containers.execute(delete(HumanInputForm))
|
||||
db_session_with_containers.execute(delete(WorkflowPause))
|
||||
db_session_with_containers.execute(delete(WorkflowRun))
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.execute(delete(Account))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.email_i18n import EmailType
|
||||
@ -44,9 +45,9 @@ class TestMailInviteMemberTask:
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
"""Clean up database before each test to ensure isolation."""
|
||||
# Clear all test data
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.execute(delete(Account))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
@ -491,10 +492,10 @@ class TestMailInviteMemberTask:
|
||||
assert tenant.name is not None
|
||||
|
||||
# Verify tenant relationship exists
|
||||
tenant_join = (
|
||||
db_session_with_containers.query(TenantAccountJoin)
|
||||
.filter_by(tenant_id=tenant.id, account_id=pending_account.id)
|
||||
.first()
|
||||
tenant_join = db_session_with_containers.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == pending_account.id)
|
||||
.limit(1)
|
||||
)
|
||||
assert tenant_join is not None
|
||||
assert tenant_join.role == TenantAccountRole.NORMAL
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import ANY, call, patch
|
||||
import pytest
|
||||
from graphon.variables.segments import StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.storage.storage_type import StorageType
|
||||
@ -20,11 +21,11 @@ from tasks.remove_app_and_related_data_task import (
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(db_session_with_containers):
|
||||
db_session_with_containers.query(WorkflowDraftVariable).delete()
|
||||
db_session_with_containers.query(WorkflowDraftVariableFile).delete()
|
||||
db_session_with_containers.query(UploadFile).delete()
|
||||
db_session_with_containers.query(App).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.execute(delete(WorkflowDraftVariable))
|
||||
db_session_with_containers.execute(delete(WorkflowDraftVariableFile))
|
||||
db_session_with_containers.execute(delete(UploadFile))
|
||||
db_session_with_containers.execute(delete(App))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
@ -127,21 +128,21 @@ class TestDeleteDraftVariablesBatch:
|
||||
result = delete_draft_variables_batch(app1.id, batch_size=100)
|
||||
|
||||
assert result == 150
|
||||
app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
|
||||
WorkflowDraftVariable.app_id == app1.id
|
||||
app1_remaining_count = db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app1.id)
|
||||
)
|
||||
app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
|
||||
WorkflowDraftVariable.app_id == app2.id
|
||||
app2_remaining_count = db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app2.id)
|
||||
)
|
||||
assert app1_remaining.count() == 0
|
||||
assert app2_remaining.count() == 100
|
||||
assert app1_remaining_count == 0
|
||||
assert app2_remaining_count == 100
|
||||
|
||||
def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers):
|
||||
"""Test deletion when no draft variables exist for the app."""
|
||||
result = delete_draft_variables_batch(str(uuid.uuid4()), 1000)
|
||||
|
||||
assert result == 0
|
||||
assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0
|
||||
assert db_session_with_containers.scalar(select(func.count()).select_from(WorkflowDraftVariable)) == 0
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.logger")
|
||||
@ -190,12 +191,16 @@ class TestDeleteDraftVariableOffloadData:
|
||||
expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys]
|
||||
mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)
|
||||
|
||||
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
|
||||
WorkflowDraftVariableFile.id.in_(file_ids)
|
||||
remaining_var_files_count = db_session_with_containers.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(file_ids))
|
||||
)
|
||||
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
assert remaining_var_files.count() == 0
|
||||
assert remaining_upload_files.count() == 0
|
||||
remaining_upload_files_count = db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
assert remaining_var_files_count == 0
|
||||
assert remaining_upload_files_count == 0
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
@patch("tasks.remove_app_and_related_data_task.logging")
|
||||
@ -217,9 +222,13 @@ class TestDeleteDraftVariableOffloadData:
|
||||
assert result == 1
|
||||
mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0])
|
||||
|
||||
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
|
||||
WorkflowDraftVariableFile.id.in_(file_ids)
|
||||
remaining_var_files_count = db_session_with_containers.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(file_ids))
|
||||
)
|
||||
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
assert remaining_var_files.count() == 0
|
||||
assert remaining_upload_files.count() == 0
|
||||
remaining_upload_files_count = db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
assert remaining_var_files_count == 0
|
||||
assert remaining_upload_files_count == 0
|
||||
|
||||
@ -11,6 +11,7 @@ from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -40,9 +41,9 @@ def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[T
|
||||
yield tenant, account
|
||||
|
||||
# Cleanup
|
||||
db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
|
||||
db_session_with_containers.query(Account).filter_by(id=account.id).delete()
|
||||
db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete()
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == tenant.id))
|
||||
db_session_with_containers.execute(delete(Account).where(Account.id == account.id))
|
||||
db_session_with_containers.execute(delete(Tenant).where(Tenant.id == tenant.id))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
@ -93,14 +94,14 @@ def app_model(
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
|
||||
db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete()
|
||||
db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(App).filter_by(id=app.id).delete()
|
||||
db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(WorkflowPluginTrigger).where(WorkflowPluginTrigger.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(AppTrigger).where(AppTrigger.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant.id))
|
||||
db_session_with_containers.execute(delete(Workflow).where(Workflow.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(App).where(App.id == app.id))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ import pytest
from flask import Flask, Response
from flask.testing import FlaskClient
from graphon.enums import BuiltinNodeTypes
from sqlalchemy import select
from sqlalchemy.orm import Session

from configs import dify_config
@ -227,7 +228,9 @@ def test_webhook_trigger_creates_trigger_log(
    assert response.status_code == 200

    db_session_with_containers.expire_all()
    logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
    logs = db_session_with_containers.scalars(
        select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
    ).all()
    assert logs, "Webhook trigger should create trigger log"

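Here and in the schedule and plugin trigger tests below, `Query.filter_by(...).all()` becomes `Session.scalars(select(...)).all()`, the 2.0-style way of fetching ORM rows; the batch-deletion tests earlier in this diff use the matching `scalar(select(func.count())...)` idiom for counts. A compact sketch of both retrieval styles with a stand-in model (not the real WorkflowTriggerLog):

    from sqlalchemy import create_engine, func, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class TriggerLog(Base):  # stand-in for WorkflowTriggerLog
        __tablename__ = "trigger_logs"
        id: Mapped[int] = mapped_column(primary_key=True)
        app_id: Mapped[str] = mapped_column()


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(TriggerLog(app_id="app-1"))
        session.commit()

        # scalars(...) returns ORM instances, mirroring query(...).filter_by(...).all()
        logs = session.scalars(select(TriggerLog).where(TriggerLog.app_id == "app-1")).all()
        assert logs and logs[0].app_id == "app-1"

        # scalar(select(func.count())...) mirrors query(...).count()
        assert session.scalar(select(func.count()).select_from(TriggerLog)) == 1
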
@ -611,7 +614,9 @@ def test_schedule_trigger_creates_trigger_log(
|
||||
|
||||
# Verify WorkflowTriggerLog was created
|
||||
db_session_with_containers.expire_all()
|
||||
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
|
||||
logs = db_session_with_containers.scalars(
|
||||
select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
|
||||
).all()
|
||||
assert logs, "Schedule trigger should create WorkflowTriggerLog"
|
||||
assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE
|
||||
assert logs[0].root_node_id == schedule_node_id
|
||||
@ -786,11 +791,12 @@ def test_plugin_trigger_full_chain_with_db_verification(
|
||||
|
||||
# Verify database records exist
|
||||
db_session_with_containers.expire_all()
|
||||
plugin_triggers = (
|
||||
db_session_with_containers.query(WorkflowPluginTrigger)
|
||||
.filter_by(app_id=app_model.id, node_id=plugin_node_id)
|
||||
.all()
|
||||
)
|
||||
plugin_triggers = db_session_with_containers.scalars(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_model.id,
|
||||
WorkflowPluginTrigger.node_id == plugin_node_id,
|
||||
)
|
||||
).all()
|
||||
assert plugin_triggers, "WorkflowPluginTrigger record should exist"
|
||||
assert plugin_triggers[0].provider_id == provider_id
|
||||
assert plugin_triggers[0].event_name == "test_event"
|
||||
|
||||
@ -14,18 +14,20 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailPasswordLoginLimitError,
|
||||
InvalidEmailError,
|
||||
)
|
||||
from controllers.console.auth.login import LoginApi, LogoutApi
|
||||
from controllers.console.auth.login import EmailCodeLoginApi, LoginApi, LogoutApi
|
||||
from controllers.console.error import (
|
||||
AccountBannedError,
|
||||
AccountInFreezeError,
|
||||
WorkspacesLimitExceeded,
|
||||
)
|
||||
from services.entities.auth_entities import LoginFailureReason
|
||||
from services.errors.account import AccountLoginError, AccountPasswordError
|
||||
|
||||
|
||||
@ -34,6 +36,11 @@ def encode_password(password: str) -> str:
|
||||
return base64.b64encode(password.encode("utf-8")).decode()
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
"""Helper to encode verification code as Base64 for testing."""
|
||||
return base64.b64encode(code.encode("utf-8")).decode()
|
||||
|
||||
|
||||
class TestLoginApi:
|
||||
"""Test cases for the LoginApi endpoint."""
|
||||
|
||||
@ -197,12 +204,17 @@ class TestLoginApi:
|
||||
mock_get_invitation.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(EmailPasswordLoginLimitError):
|
||||
login_api.post()
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(EmailPasswordLoginLimitError):
|
||||
login_api.post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "test@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.LOGIN_RATE_LIMITED
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True)
|
||||
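The rate-limit, freeze, wrong-password, and banned-account cases now all wrap the request in `patch("controllers.console.auth.login.logger.warning")` and assert on the recorded call. Because the controller is expected to log a format string followed by positional arguments, `call_args.args[1]` is the email and `call_args.args[2]` the failure reason. A small stand-alone sketch of that assertion style (the logger and helper below are stand-ins, not the controller code):

    import logging
    from unittest.mock import patch

    logger = logging.getLogger("controllers.console.auth.login")  # stand-in target


    def record_login_failure(email: str, reason: str) -> None:
        # format string first, then positional args, matching what the tests assert on
        logger.warning("login failed for %s: %s", email, reason)


    with patch.object(logger, "warning") as mock_log_warning:
        record_login_failure("test@example.com", "login_rate_limited")

    assert mock_log_warning.call_count == 1
    assert mock_log_warning.call_args.args[1] == "test@example.com"
    assert mock_log_warning.call_args.args[2] == "login_rate_limited"
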
@ -220,12 +232,17 @@ class TestLoginApi:
|
||||
mock_is_frozen.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
login_api.post()
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
login_api.post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "frozen@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_IN_FREEZE
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@ -257,14 +274,20 @@ class TestLoginApi:
|
||||
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
login_api.post()
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "password": encode_password("WrongPass123!")},
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
login_api.post()
|
||||
|
||||
mock_add_rate_limit.assert_called_once_with("test@example.com")
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "test@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@ -288,12 +311,19 @@ class TestLoginApi:
|
||||
mock_authenticate.side_effect = AccountLoginError("Account is banned")
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountBannedError):
|
||||
login_api.post()
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "banned@example.com", "password": encode_password("ValidPass123!")},
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountBannedError):
|
||||
login_api.post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "banned@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@ -417,6 +447,36 @@ class TestLoginApi:
|
||||
mock_add_rate_limit.assert_not_called()
|
||||
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
|
||||
@patch("controllers.console.auth.login._get_account_with_case_fallback")
|
||||
def test_email_code_login_logs_banned_account(
|
||||
self,
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_token_data,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
mock_get_account.side_effect = Unauthorized("Account is banned.")
|
||||
|
||||
with patch("controllers.console.auth.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
EmailCodeLoginApi().post()
|
||||
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "user@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
"""Test cases for the LogoutApi endpoint."""
|
||||
|
||||
@ -4,9 +4,12 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from jwt import InvalidTokenError
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services.errors.account
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi
|
||||
from services.entities.auth_entities import LoginFailureReason
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
@ -115,13 +118,18 @@ class TestLoginApi:
|
||||
def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.error import AccountBannedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
LoginApi().post()
|
||||
with patch("controllers.web.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
LoginApi().post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "user@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
@ -130,13 +138,87 @@ class TestLoginApi:
|
||||
def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
LoginApi().post()
|
||||
with patch("controllers.web.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
LoginApi().post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "user@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountNotFoundError(),
|
||||
)
|
||||
def test_login_account_not_found(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
|
||||
with patch("controllers.web.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "missing@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
LoginApi().post()
|
||||
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "missing@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_NOT_FOUND
|
||||
|
||||
@patch("controllers.web.login.WebAppAuthService.get_email_code_login_data", return_value=None)
|
||||
def test_email_code_login_logs_invalid_token(self, mock_get_token_data: MagicMock, app: Flask) -> None:
|
||||
with patch("controllers.web.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/web/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": encode_code("123456"), "token": "token-123"},
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
EmailCodeLoginApi().post()
|
||||
|
||||
mock_get_token_data.assert_called_once_with("token-123")
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "user@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_EMAIL_CODE_TOKEN
|
||||
|
||||
@patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token")
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.get_user_through_email",
|
||||
side_effect=Unauthorized("Account is banned."),
|
||||
)
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.get_email_code_login_data",
|
||||
return_value={"email": "User@Example.com", "code": "123456"},
|
||||
)
|
||||
def test_email_code_login_logs_banned_account(
|
||||
self,
|
||||
mock_get_token_data: MagicMock,
|
||||
mock_get_user: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
from controllers.console.error import AccountBannedError
|
||||
|
||||
with patch("controllers.web.login.logger.warning") as mock_log_warning:
|
||||
with app.test_request_context(
|
||||
"/web/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
EmailCodeLoginApi().post()
|
||||
|
||||
mock_get_token_data.assert_called_once_with("token-123")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
assert mock_log_warning.call_count == 1
|
||||
assert mock_log_warning.call_args.args[1] == "user@example.com"
|
||||
assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED
|
||||
|
||||
|
||||
class TestLoginStatusApi:
|
||||
|
||||
63
pnpm-lock.yaml
generated
@ -16,14 +16,17 @@ catalogs:
|
||||
specifier: 8.1.1
|
||||
version: 8.1.1
|
||||
'@base-ui/react':
|
||||
specifier: 1.3.0
|
||||
version: 1.3.0
|
||||
specifier: 1.4.0
|
||||
version: 1.4.0
|
||||
'@chromatic-com/storybook':
|
||||
specifier: 5.1.1
|
||||
version: 5.1.1
|
||||
'@cucumber/cucumber':
|
||||
specifier: 12.7.0
|
||||
version: 12.7.0
|
||||
'@date-fns/tz':
|
||||
specifier: 1.2.0
|
||||
version: 1.2.0
|
||||
'@egoist/tailwindcss-icons':
|
||||
specifier: 1.9.2
|
||||
version: 1.9.2
|
||||
@ -267,6 +270,9 @@ catalogs:
|
||||
cron-parser:
|
||||
specifier: 5.5.0
|
||||
version: 5.5.0
|
||||
date-fns:
|
||||
specifier: 4.0.0
|
||||
version: 4.0.0
|
||||
dayjs:
|
||||
specifier: 1.11.20
|
||||
version: 1.11.20
|
||||
@ -433,8 +439,8 @@ catalogs:
|
||||
specifier: 5.2.4
|
||||
version: 5.2.4
|
||||
react-i18next:
|
||||
specifier: 17.0.2
|
||||
version: 17.0.2
|
||||
specifier: 16.5.8
|
||||
version: 16.5.8
|
||||
react-multi-email:
|
||||
specifier: 1.0.25
|
||||
version: 1.0.25
|
||||
@ -648,7 +654,10 @@ importers:
|
||||
version: 1.27.6(@amplitude/rrweb@2.0.0-alpha.37)(rollup@4.59.0)
|
||||
'@base-ui/react':
|
||||
specifier: 'catalog:'
|
||||
version: 1.3.0(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
version: 1.4.0(@date-fns/tz@1.2.0)(@types/react@19.2.14)(date-fns@4.0.0)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
'@date-fns/tz':
|
||||
specifier: 'catalog:'
|
||||
version: 1.2.0
|
||||
'@emoji-mart/data':
|
||||
specifier: 'catalog:'
|
||||
version: 1.2.1
|
||||
@ -751,6 +760,9 @@ importers:
|
||||
cron-parser:
|
||||
specifier: 'catalog:'
|
||||
version: 5.5.0
|
||||
date-fns:
|
||||
specifier: 'catalog:'
|
||||
version: 4.0.0
|
||||
dayjs:
|
||||
specifier: 'catalog:'
|
||||
version: 1.11.20
|
||||
@ -876,7 +888,7 @@ importers:
|
||||
version: 5.2.4(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
react-i18next:
|
||||
specifier: 'catalog:'
|
||||
version: 17.0.2(i18next@26.0.4(typescript@6.0.2))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(typescript@6.0.2)
|
||||
version: 16.5.8(i18next@26.0.4(typescript@6.0.2))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(typescript@6.0.2)
|
||||
react-multi-email:
|
||||
specifier: 'catalog:'
|
||||
version: 1.0.25(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
@ -1391,19 +1403,21 @@ packages:
|
||||
resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@base-ui/react@1.3.0':
|
||||
resolution: {integrity: sha512-FwpKqZbPz14AITp1CVgf4AjhKPe1OeeVKSBMdgD10zbFlj3QSWelmtCMLi2+/PFZZcIm3l87G7rwtCZJwHyXWA==}
|
||||
'@base-ui/react@1.4.0':
|
||||
resolution: {integrity: sha512-QcqdVbr/+ba2/RAKJIV1PV6S02Q5+r6a4Eym8ndBw+ZbBILkkmQAyRxXCg/pArrHnkrGeU8goe26aw0h6eE8pg==}
|
||||
engines: {node: '>=14.0.0'}
|
||||
peerDependencies:
|
||||
'@date-fns/tz': ^1.2.0
|
||||
'@types/react': ^17 || ^18 || ^19
|
||||
date-fns: ^4.0.0
|
||||
react: ^17 || ^18 || ^19
|
||||
react-dom: ^17 || ^18 || ^19
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
|
||||
'@base-ui/utils@0.2.6':
|
||||
resolution: {integrity: sha512-yQ+qeuqohwhsNpoYDqqXaLllYAkPCP4vYdDrVo8FQXaAPfHWm1pG/Vm+jmGTA5JFS0BAIjookyapuJFY8F9PIw==}
|
||||
'@base-ui/utils@0.2.7':
|
||||
resolution: {integrity: sha512-nXYKhiL/0JafyJE8PfcflipGftOftlIwKd72rU15iZ1M5yqgg5J9P8NHU71GReDuXco5MJA/eVQqUT5WRqX9sA==}
|
||||
peerDependencies:
|
||||
'@types/react': ^17 || ^18 || ^19
|
||||
react: ^17 || ^18 || ^19
|
||||
@ -1532,6 +1546,9 @@ packages:
|
||||
'@cucumber/tag-expressions@9.1.0':
|
||||
resolution: {integrity: sha512-bvHjcRFZ+J1TqIa9eFNO1wGHqwx4V9ZKV3hYgkuK/VahHx73uiP4rKV3JVrvWSMrwrFvJG6C8aEwnCWSvbyFdQ==}
|
||||
|
||||
'@date-fns/tz@1.2.0':
|
||||
resolution: {integrity: sha512-LBrd7MiJZ9McsOgxqWX7AaxrDjcFVjWH/tIKJd7pnR7McaslGYOP1QmmiBXdJH/H/yLCT+rcQ7FaPBUxRGUtrg==}
|
||||
|
||||
'@e18e/eslint-plugin@0.3.0':
|
||||
resolution: {integrity: sha512-hHgfpxsrZ2UYHcicA+tGZnmk19uJTaye9VH79O+XS8R4ona2Hx3xjhXghclNW58uXMk3xXlbYEOMr8thsoBmWg==}
|
||||
peerDependencies:
|
||||
@ -5325,6 +5342,9 @@ packages:
|
||||
dagre-d3-es@7.0.14:
|
||||
resolution: {integrity: sha512-P4rFMVq9ESWqmOgK+dlXvOtLwYg0i7u0HBGJER0LZDJT2VHIPAMZ/riPxqJceWMStH5+E61QxFra9kIS3AqdMg==}
|
||||
|
||||
date-fns@4.0.0:
|
||||
resolution: {integrity: sha512-6K33+I8fQ5otvHgLIvKK1xmMbLAh0pduyrx7dwMXKiGYeoWhmk6M3Zoak9n7bXHMJQlHq1yqmdGy1QxKddJjUA==}
|
||||
|
||||
dayjs@1.11.20:
|
||||
resolution: {integrity: sha512-YbwwqR/uYpeoP4pu043q+LTDLFBLApUP6VxRihdfNTqu4ubqMlGDLd6ErXhEgsyvY0K6nCs7nggYumAN+9uEuQ==}
|
||||
|
||||
@ -7376,14 +7396,14 @@ packages:
|
||||
react: '>=16.8.0'
|
||||
react-dom: '>=16.8.0'
|
||||
|
||||
react-i18next@17.0.2:
|
||||
resolution: {integrity: sha512-shBftH2vaTWK2Bsp7FiL+cevx3xFJlvFxmsDFQSrJc+6twHkP0tv/bGa01VVWzpreUVVwU+3Hev5iFqRg65RwA==}
|
||||
react-i18next@16.5.8:
|
||||
resolution: {integrity: sha512-2ABeHHlakxVY+LSirD+OiERxFL6+zip0PaHo979bgwzeHg27Sqc82xxXWIrSFmfWX0ZkrvXMHwhsi/NGUf5VQg==}
|
||||
peerDependencies:
|
||||
i18next: '>= 26.0.1'
|
||||
i18next: '>= 25.6.2'
|
||||
react: '>= 16.8.0'
|
||||
react-dom: '*'
|
||||
react-native: '*'
|
||||
typescript: ^5 || ^6
|
||||
typescript: ^5
|
||||
peerDependenciesMeta:
|
||||
react-dom:
|
||||
optional: true
|
||||
@ -8896,20 +8916,21 @@ snapshots:
|
||||
'@babel/helper-string-parser': 7.27.1
|
||||
'@babel/helper-validator-identifier': 7.28.5
|
||||
|
||||
'@base-ui/react@1.3.0(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
|
||||
'@base-ui/react@1.4.0(@date-fns/tz@1.2.0)(@types/react@19.2.14)(date-fns@4.0.0)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
|
||||
dependencies:
|
||||
'@babel/runtime': 7.29.2
|
||||
'@base-ui/utils': 0.2.6(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
'@base-ui/utils': 0.2.7(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
'@date-fns/tz': 1.2.0
|
||||
'@floating-ui/react-dom': 2.1.8(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
'@floating-ui/utils': 0.2.11
|
||||
date-fns: 4.0.0
|
||||
react: 19.2.5
|
||||
react-dom: 19.2.5(react@19.2.5)
|
||||
tabbable: 6.4.0
|
||||
use-sync-external-store: 1.6.0(react@19.2.5)
|
||||
optionalDependencies:
|
||||
'@types/react': 19.2.14
|
||||
|
||||
'@base-ui/utils@0.2.6(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
|
||||
'@base-ui/utils@0.2.7(@types/react@19.2.14)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
|
||||
dependencies:
|
||||
'@babel/runtime': 7.29.2
|
||||
'@floating-ui/utils': 0.2.11
|
||||
@ -9128,6 +9149,8 @@ snapshots:
|
||||
|
||||
'@cucumber/tag-expressions@9.1.0': {}
|
||||
|
||||
'@date-fns/tz@1.2.0': {}
|
||||
|
||||
'@e18e/eslint-plugin@0.3.0(eslint@10.2.0(jiti@2.6.1))(oxlint@1.58.0(oxlint-tsgolint@0.20.0))':
|
||||
dependencies:
|
||||
eslint-plugin-depend: 1.5.0(eslint@10.2.0(jiti@2.6.1))
|
||||
@ -12793,6 +12816,8 @@ snapshots:
|
||||
d3: 7.9.0
|
||||
lodash-es: 4.18.0
|
||||
|
||||
date-fns@4.0.0: {}
|
||||
|
||||
dayjs@1.11.20: {}
|
||||
|
||||
debug@4.4.3(supports-color@8.1.1):
|
||||
@ -15458,7 +15483,7 @@ snapshots:
|
||||
react: 19.2.5
|
||||
react-dom: 19.2.5(react@19.2.5)
|
||||
|
||||
react-i18next@17.0.2(i18next@26.0.4(typescript@6.0.2))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(typescript@6.0.2):
|
||||
react-i18next@16.5.8(i18next@26.0.4(typescript@6.0.2))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(typescript@6.0.2):
|
||||
dependencies:
|
||||
'@babel/runtime': 7.29.2
|
||||
html-parse-stringify: 3.0.1
|
||||
|
||||
@ -48,7 +48,8 @@ catalog:
|
||||
"@amplitude/analytics-browser": 2.38.1
|
||||
"@amplitude/plugin-session-replay-browser": 1.27.6
|
||||
"@antfu/eslint-config": 8.1.1
|
||||
"@base-ui/react": 1.3.0
|
||||
"@base-ui/react": 1.4.0
|
||||
"@date-fns/tz": 1.2.0
|
||||
"@chromatic-com/storybook": 5.1.1
|
||||
"@cucumber/cucumber": 12.7.0
|
||||
"@egoist/tailwindcss-icons": 1.9.2
|
||||
@ -135,6 +136,7 @@ catalog:
|
||||
code-inspector-plugin: 1.5.1
|
||||
copy-to-clipboard: 3.3.3
|
||||
cron-parser: 5.5.0
|
||||
date-fns: 4.0.0
|
||||
dayjs: 1.11.20
|
||||
decimal.js: 10.6.0
|
||||
dompurify: 3.3.3
|
||||
@ -191,7 +193,7 @@ catalog:
|
||||
react-dom: 19.2.5
|
||||
react-easy-crop: 5.5.7
|
||||
react-hotkeys-hook: 5.2.4
|
||||
react-i18next: 17.0.2
|
||||
react-i18next: 16.5.8
|
||||
react-multi-email: 1.0.25
|
||||
react-papaparse: 4.4.0
|
||||
react-pdf-highlighter: 8.0.0-rc.0
|
||||
|
||||
@ -41,6 +41,33 @@ const renderOpenSelect = ({
|
||||
}
|
||||
|
||||
describe('Select wrappers', () => {
|
||||
describe('Select root integration', () => {
|
||||
it('should associate the hidden input with an external form and preserve autocomplete hints', () => {
|
||||
const formId = 'profile-form'
|
||||
const { container } = render(
|
||||
<>
|
||||
<form id={formId} />
|
||||
<Select defaultValue="seattle" name="city" form={formId} autoComplete="address-level2">
|
||||
<SelectTrigger aria-label="city select">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent listProps={{ 'role': 'listbox', 'aria-label': 'select list' }}>
|
||||
<SelectItem value="seattle">Seattle</SelectItem>
|
||||
<SelectItem value="new-york">New York</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</>,
|
||||
)
|
||||
|
||||
const hiddenInput = container.querySelector('input[name="city"]')
|
||||
const form = container.querySelector(`#${formId}`) as HTMLFormElement
|
||||
|
||||
expect(hiddenInput).toHaveAttribute('form', formId)
|
||||
expect(hiddenInput).toHaveAttribute('autocomplete', 'address-level2')
|
||||
expect(new FormData(form).get('city')).toBe('seattle')
|
||||
})
|
||||
})
|
||||
|
||||
describe('SelectTrigger', () => {
|
||||
it('should render clear button when clearable is true and loading is false', () => {
|
||||
renderOpenSelect({
|
||||
|
||||
@ -47,6 +47,23 @@ describe('Slider', () => {
    expect(onValueChange).toHaveBeenLastCalledWith(21, expect.anything())
  })

  it('should round floating point keyboard updates to the configured step', async () => {
    const user = userEvent.setup()
    const onValueChange = vi.fn()

    render(<Slider value={0.2} min={0} max={1} step={0.1} onValueChange={onValueChange} aria-label="Value" />)

    const slider = getSliderInput()

    await act(async () => {
      slider.focus()
      await user.keyboard('{ArrowRight}')
    })

    expect(onValueChange).toHaveBeenCalledTimes(1)
    expect(onValueChange).toHaveBeenLastCalledWith(0.3, expect.anything())
  })

  it('should not trigger onValueChange when disabled', async () => {
    const user = userEvent.setup()
    const onValueChange = vi.fn()

@ -251,6 +251,32 @@ describe('base/ui/toast', () => {
|
||||
expect(screen.queryByText('Loading')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Re-adding the same toast id should upsert in place instead of stacking duplicates.
|
||||
it('should upsert an existing toast when add is called with the same id', async () => {
|
||||
render(<ToastHost />)
|
||||
|
||||
act(() => {
|
||||
toast('Syncing', {
|
||||
id: 'sync-job',
|
||||
description: 'Uploading changes…',
|
||||
})
|
||||
})
|
||||
|
||||
expect(await screen.findByText('Syncing')).toBeInTheDocument()
|
||||
|
||||
act(() => {
|
||||
toast.success('Synced', {
|
||||
id: 'sync-job',
|
||||
description: 'All changes are uploaded.',
|
||||
})
|
||||
})
|
||||
|
||||
expect(screen.queryByText('Syncing')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('Synced')).toBeInTheDocument()
|
||||
expect(screen.getByText('All changes are uploaded.')).toBeInTheDocument()
|
||||
expect(screen.getAllByRole('dialog')).toHaveLength(1)
|
||||
})
|
||||
|
||||
// Action props should pass through to the Base UI action button.
|
||||
it('should render and invoke toast action props', async () => {
|
||||
const onAction = vi.fn()
|
||||
|
||||
@ -501,6 +501,34 @@ describe('http path', () => {
    expect(onChange).toHaveBeenCalled()
  })

  it('should only append a new key-value row after the last value field receives content', () => {
    const onChange = vi.fn()
    const onRemove = vi.fn()
    const onAdd = vi.fn()
    render(
      <KeyValueItem
        instanceId="kv-append"
        nodeId="node-1"
        readonly={false}
        canRemove
        payload={{ id: 'kv-append', key: 'name', value: '', type: 'text' } as any}
        onChange={onChange}
        onRemove={onRemove}
        isLastItem
        onAdd={onAdd}
      />,
    )

    const valueInput = screen.getAllByPlaceholderText('workflow.nodes.http.insertVarPlaceholder')[1]!

    fireEvent.click(valueInput)
    expect(onAdd).not.toHaveBeenCalled()

    fireEvent.change(valueInput, { target: { value: 'alice' } })
    expect(onChange).toHaveBeenCalledWith(expect.objectContaining({ value: 'alice' }))
    expect(onAdd).toHaveBeenCalledTimes(1)
  })

  it('should edit key-only rows and select file payload rows', async () => {
    const user = userEvent.setup()
    const onChange = vi.fn()

@ -47,20 +47,37 @@ const KeyValueItem: FC<Props> = ({
  insertVarTipToLeft,
}) => {
  const { t } = useTranslation()
  const hasValuePayload = payload.type === 'file'
    ? !!payload.file?.length
    : !!payload.value

  const handleChange = useCallback((key: string) => {
    return (value: string | ValueSelector) => {
      const shouldAddNextItem = isLastItem
        && (
          (key === 'value' && !payload.value && !!value)
          || (key === 'file' && (!payload.file || payload.file.length === 0) && Array.isArray(value) && value.length > 0)
        )

      const newPayload = produce(payload, (draft: any) => {
        draft[key] = value
      })
      onChange(newPayload)

      if (shouldAddNextItem)
        onAdd()
    }
  }, [onChange, payload])
  }, [isLastItem, onAdd, onChange, payload])

  const filterOnlyFileVariable = (varPayload: Var) => {
    return [VarType.file, VarType.arrayFile].includes(varPayload.type)
  }

  const handleValueContainerClick = useCallback(() => {
    if (isLastItem && hasValuePayload)
      onAdd()
  }, [hasValuePayload, isLastItem, onAdd])

  return (
    // group class name is for hover row show remove button
    <div className={cn(className, 'h-min-7 group flex border-t border-divider-regular')}>
@ -102,7 +119,10 @@ const KeyValueItem: FC<Props> = ({
          />
        </div>
      )}
      <div className={cn(isSupportFile ? 'grow' : 'w-1/2')} onClick={() => isLastItem && onAdd()}>
      <div
        className={cn(isSupportFile ? 'grow' : 'w-1/2')}
        onClick={handleValueContainerClick}
      >
        {(isSupportFile && payload.type === 'file')
          ? (
            <VarReferencePicker

@ -56,6 +56,7 @@
|
||||
"@amplitude/analytics-browser": "catalog:",
|
||||
"@amplitude/plugin-session-replay-browser": "catalog:",
|
||||
"@base-ui/react": "catalog:",
|
||||
"@date-fns/tz": "catalog:",
|
||||
"@emoji-mart/data": "catalog:",
|
||||
"@floating-ui/react": "catalog:",
|
||||
"@formatjs/intl-localematcher": "catalog:",
|
||||
@ -90,6 +91,7 @@
|
||||
"cmdk": "catalog:",
|
||||
"copy-to-clipboard": "catalog:",
|
||||
"cron-parser": "catalog:",
|
||||
"date-fns": "catalog:",
|
||||
"dayjs": "catalog:",
|
||||
"decimal.js": "catalog:",
|
||||
"dompurify": "catalog:",
|
||||
|
||||
@ -1,11 +1,21 @@
import type { Schema } from 'jsonschema'
import type { Schema, ValidationError, ValidatorResult } from 'jsonschema'
import { Validator } from 'jsonschema'
import draft07Schema from './draft-07.json'

const validator = new Validator()

export const draft07Validator = (schema: any) => {
  return validator.validate(schema, draft07Schema as unknown as Schema)
type Draft07ValidationResult = Pick<ValidatorResult, 'valid' | 'errors'>

export const draft07Validator = (schema: any): Draft07ValidationResult => {
  try {
    return validator.validate(schema, draft07Schema as unknown as Schema)
  }
  catch {
    // The jsonschema library may throw URL errors in browser environments
    // when resolving schema $id URIs. Return empty errors since structural
    // validation is handled separately by preValidateSchema (#34841).
    return { valid: true, errors: [] as ValidationError[] }
  }
}

export const forbidBooleanProperties = (schema: any, path: string[] = []): string[] => {