mirror of https://github.com/langgenius/dify.git
refactor: migrate some ns.model to BaseModel (#30388)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
e3ef33366d
commit
5b02e5dcb6
|
|
@ -1,62 +1,59 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import AppIconUrlField
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
parameters__system_parameters = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
JSONObject: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
def build_system_parameters_model(api_or_ns: Api | Namespace):
|
||||
"""Build the system parameters model for the API or Namespace."""
|
||||
return api_or_ns.model("SystemParameters", parameters__system_parameters)
|
||||
class SystemParameters(BaseModel):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
audio_file_size_limit: int
|
||||
file_size_limit: int
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(parameters__system_parameters),
|
||||
}
|
||||
class Parameters(BaseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: JSONObject
|
||||
speech_to_text: JSONObject
|
||||
text_to_speech: JSONObject
|
||||
retriever_resource: JSONObject
|
||||
annotation_reply: JSONObject
|
||||
more_like_this: JSONObject
|
||||
user_input_form: list[JSONObject]
|
||||
sensitive_word_avoidance: JSONObject
|
||||
file_upload: JSONObject
|
||||
system_parameters: SystemParameters
|
||||
|
||||
|
||||
def build_parameters_model(api_or_ns: Api | Namespace):
|
||||
"""Build the parameters model for the API or Namespace."""
|
||||
copied_fields = parameters_fields.copy()
|
||||
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
|
||||
return api_or_ns.model("Parameters", copied_fields)
|
||||
class Site(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
title: str
|
||||
chat_color_theme: str | None = None
|
||||
chat_color_theme_inverted: bool
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
description: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
default_language: str
|
||||
show_workflow_steps: bool
|
||||
use_icon_as_answer_icon: bool
|
||||
|
||||
site_fields = {
|
||||
"title": fields.String,
|
||||
"chat_color_theme": fields.String,
|
||||
"chat_color_theme_inverted": fields.Boolean,
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"description": fields.String,
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"default_language": fields.String,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
|
||||
def build_site_model(api_or_ns: Api | Namespace):
|
||||
"""Build the site model for the API or Namespace."""
|
||||
return api_or_ns.model("Site", site_fields)
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
if self.icon and self.icon_type == IconType.IMAGE:
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from flask_restx import marshal_with
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
|
|
@ -13,7 +11,6 @@ from services.app_service import AppService
|
|||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
|
|
@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
|
|||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ annotation_list_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_annotation_list_model(api_or_ns: Api | Namespace):
|
||||
def build_annotation_list_model(api_or_ns: Namespace):
|
||||
"""Build the annotation list model for the API or Namespace."""
|
||||
copied_annotation_list_fields = annotation_list_fields.copy()
|
||||
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.fields import build_parameters_model
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
|
|
@ -23,7 +23,6 @@ class AppParameterApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_parameters_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters.
|
||||
|
||||
|
|
@ -45,7 +44,8 @@ class AppParameterApi(Resource):
|
|||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/meta")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.fields import build_site_model
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -23,7 +23,6 @@ class AppSiteApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_site_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app site info.
|
||||
|
||||
|
|
@ -38,4 +37,4 @@ class AppSiteApi(Resource):
|
|||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return site
|
||||
return SiteResponse.model_validate(site).model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Any, Literal
|
|||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
|
@ -78,7 +78,7 @@ workflow_run_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_workflow_run_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_run_model(api_or_ns: Namespace):
|
||||
"""Build the workflow run model for the API or Namespace."""
|
||||
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
|
@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
|
|
@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
|
|||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/meta")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ annotation_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_annotation_model(api_or_ns: Api | Namespace):
|
||||
def build_annotation_model(api_or_ns: Namespace):
|
||||
"""Build the annotation model for the API or Namespace."""
|
||||
return api_or_ns.model("Annotation", annotation_fields)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
|
|
@ -46,7 +46,7 @@ message_file_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_message_file_model(api_or_ns: Api | Namespace):
|
||||
def build_message_file_model(api_or_ns: Namespace):
|
||||
"""Build the message file fields for the API or Namespace."""
|
||||
return api_or_ns.model("MessageFile", message_file_fields)
|
||||
|
||||
|
|
@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
|
||||
simple_conversation_model = build_simple_conversation_model(api_or_ns)
|
||||
|
||||
|
|
@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa
|
|||
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
|
||||
|
||||
|
||||
def build_conversation_delete_model(api_or_ns: Api | Namespace):
|
||||
def build_conversation_delete_model(api_or_ns: Namespace):
|
||||
"""Build the conversation delete model for the API or Namespace."""
|
||||
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
|
||||
|
||||
|
||||
def build_simple_conversation_model(api_or_ns: Api | Namespace):
|
||||
def build_simple_conversation_model(api_or_ns: Namespace):
|
||||
"""Build the simple conversation model for the API or Namespace."""
|
||||
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
|
|
@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_conversation_variable_model(api_or_ns: Api | Namespace):
|
||||
def build_conversation_variable_model(api_or_ns: Namespace):
|
||||
"""Build the conversation variable model for the API or Namespace."""
|
||||
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
||||
|
||||
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
|
||||
# Build the nested variable model first
|
||||
conversation_variable_model = build_conversation_variable_model(api_or_ns)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
simple_end_user_fields = {
|
||||
"id": fields.String,
|
||||
|
|
@ -8,5 +8,5 @@ simple_end_user_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_simple_end_user_model(api_or_ns: Api | Namespace):
|
||||
def build_simple_end_user_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ upload_config_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_upload_config_model(api_or_ns: Api | Namespace):
|
||||
def build_upload_config_model(api_or_ns: Namespace):
|
||||
"""Build the upload config model for the API or Namespace.
|
||||
|
||||
Args:
|
||||
|
|
@ -39,7 +39,7 @@ file_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_file_model(api_or_ns: Api | Namespace):
|
||||
def build_file_model(api_or_ns: Namespace):
|
||||
"""Build the file model for the API or Namespace.
|
||||
|
||||
Args:
|
||||
|
|
@ -57,7 +57,7 @@ remote_file_info_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_remote_file_info_model(api_or_ns: Api | Namespace):
|
||||
def build_remote_file_info_model(api_or_ns: Namespace):
|
||||
"""Build the remote file info model for the API or Namespace.
|
||||
|
||||
Args:
|
||||
|
|
@ -81,7 +81,7 @@ file_fields_with_signed_url = {
|
|||
}
|
||||
|
||||
|
||||
def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
|
||||
def build_file_with_signed_url_model(api_or_ns: Namespace):
|
||||
"""Build the file with signed URL model for the API or Namespace.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from libs.helper import AvatarUrlField, TimestampField
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ simple_account_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_simple_account_model(api_or_ns: Api | Namespace):
|
||||
def build_simple_account_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField
|
||||
|
|
@ -10,7 +10,7 @@ feedback_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_feedback_model(api_or_ns: Api | Namespace):
|
||||
def build_feedback_model(api_or_ns: Namespace):
|
||||
"""Build the feedback model for the API or Namespace."""
|
||||
return api_or_ns.model("Feedback", feedback_fields)
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ agent_thought_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_agent_thought_model(api_or_ns: Api | Namespace):
|
||||
def build_agent_thought_model(api_or_ns: Namespace):
|
||||
"""Build the agent thought model for the API or Namespace."""
|
||||
return api_or_ns.model("AgentThought", agent_thought_fields)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
|
|
@ -8,5 +8,5 @@ dataset_tag_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Api | Namespace):
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
|
||||
from fields.member_fields import build_simple_account_model, simple_account_fields
|
||||
|
|
@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow app log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
|
|
@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the workflow app log pagination model for the API or Namespace."""
|
||||
# Build the nested partial model first
|
||||
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
|
|
@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
|
|||
}
|
||||
|
||||
|
||||
def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_run_for_log_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from uuid import uuid4
|
|||
import sqlalchemy as sa
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, validates
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from .base import TypeBase
|
||||
|
|
@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
|
|||
role: TenantAccountRole | None = field(default=None, init=False)
|
||||
_current_tenant: "Tenant | None" = field(default=None, init=False)
|
||||
|
||||
@validates("status")
|
||||
def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
|
||||
if isinstance(value, AccountStatus):
|
||||
return value.value
|
||||
return value
|
||||
|
||||
@property
|
||||
def is_password_set(self):
|
||||
return self.password is not None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
import builtins
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.views import MethodView as FlaskMethodView
|
||||
|
||||
_NEEDS_METHOD_VIEW_CLEANUP = False
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = FlaskMethodView
|
||||
_NEEDS_METHOD_VIEW_CLEANUP = True
|
||||
from controllers.common.fields import Parameters, Site
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from models.model import IconType
|
||||
|
||||
|
||||
def test_parameters_model_round_trip():
|
||||
parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[])
|
||||
|
||||
model = Parameters.model_validate(parameters)
|
||||
|
||||
assert model.model_dump(mode="json") == parameters
|
||||
|
||||
|
||||
def test_site_icon_url_uses_signed_url_for_image_icon():
|
||||
site = SimpleNamespace(
|
||||
title="Example",
|
||||
chat_color_theme=None,
|
||||
chat_color_theme_inverted=False,
|
||||
icon_type=IconType.IMAGE,
|
||||
icon="file-id",
|
||||
icon_background=None,
|
||||
description=None,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
default_language="en-US",
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper:
|
||||
model = Site.model_validate(site)
|
||||
|
||||
assert model.icon_url == "signed"
|
||||
mock_helper.assert_called_once_with("file-id")
|
||||
|
||||
|
||||
def test_site_icon_url_is_none_for_non_image_icon():
|
||||
site = SimpleNamespace(
|
||||
title="Example",
|
||||
chat_color_theme=None,
|
||||
chat_color_theme_inverted=False,
|
||||
icon_type=IconType.EMOJI,
|
||||
icon="file-id",
|
||||
icon_background=None,
|
||||
description=None,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
default_language="en-US",
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper:
|
||||
model = Site.model_validate(site)
|
||||
|
||||
assert model.icon_url is None
|
||||
mock_helper.assert_not_called()
|
||||
Loading…
Reference in New Issue