diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index a32c3420bb..bb2f477e3d 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -21,8 +21,6 @@ from libs.token import extract_access_token from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp from services.billing_service import BillingService, LangContentDict -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class InsertExploreAppPayload(BaseModel): app_id: str = Field(...) @@ -59,15 +57,7 @@ class InsertExploreBannerPayload(BaseModel): model_config = {"populate_by_name": True} -console_ns.schema_model( - InsertExploreAppPayload.__name__, - InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - InsertExploreBannerPayload.__name__, - InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, InsertExploreAppPayload, InsertExploreBannerPayload) def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index ed66da1be5..ad21671176 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -34,7 +34,7 @@ class AdvancedPromptTemplateList(Resource): @login_required @account_initialization_required def get(self): - args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) prompt_args: AdvancedPromptTemplateArgs = { "app_mode": args.app_mode, "model_mode": args.model_mode, diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index cfdb9cf417..c05600ced5 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -2,6 +2,7 @@ from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -10,8 +11,6 @@ from libs.login import login_required from models.model import AppMode from services.agent_service import AgentService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AgentLogQuery(BaseModel): message_id: str = Field(..., description="Message UUID") @@ -23,9 +22,7 @@ class AgentLogQuery(BaseModel): return uuid_value(value) -console_ns.schema_model( - AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, AgentLogQuery) @console_ns.route("/apps//agent/logs") @@ -44,6 +41,6 @@ class AgentLogApi(Resource): @get_app_model(mode=[AppMode.AGENT_CHAT]) def get(self, app_model): """Get agent logs""" - args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 528785931e..5970e55285 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -33,8 +33,6 @@ from services.annotation_service import ( UpsertAnnotationArgs, ) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AnnotationReplyPayload(BaseModel): score_threshold: float = Field(..., description="Score threshold for annotation matching") @@ -87,17 +85,6 @@ class AnnotationFilePayload(BaseModel): return uuid_value(value) -def reg(model: type[BaseModel]) -> None: - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(AnnotationReplyPayload) -reg(AnnotationSettingUpdatePayload) -reg(AnnotationListQuery) -reg(CreateAnnotationPayload) -reg(UpdateAnnotationPayload) -reg(AnnotationReplyStatusQuery) -reg(AnnotationFilePayload) register_schema_models( console_ns, Annotation, @@ -105,6 +92,13 @@ register_schema_models( AnnotationExportList, AnnotationHitHistory, AnnotationHitHistoryList, + AnnotationReplyPayload, + AnnotationSettingUpdatePayload, + AnnotationListQuery, + CreateAnnotationPayload, + UpdateAnnotationPayload, + AnnotationReplyStatusQuery, + AnnotationFilePayload, ) @@ -218,7 +212,7 @@ class AnnotationApi(Resource): @account_initialization_required @edit_permission_required def get(self, app_id): - args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page limit = args.limit keyword = args.keyword diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 58ed6efc14..5023d46893 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -701,7 +701,7 @@ class AppExportApi(Resource): @edit_permission_required def get(self, app_model): """Export app""" - args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) payload = AppExportResponse( data=AppDslService.export_dsl( diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 91fbe4a85a..5b673f3394 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -173,7 +173,7 @@ class TextModesApi(Resource): @account_initialization_required def get(self, app_model): try: - args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index fe274e4c9a..6a20296cff 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -37,7 +38,6 @@ from services.app_task_service import AppTaskService from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class BaseMessagePayload(BaseModel): @@ -65,13 +65,7 @@ class ChatMessagePayload(BaseMessagePayload): return uuid_value(value) -console_ns.schema_model( - CompletionMessagePayload.__name__, - CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) # define completion message api for user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index b2b1049f0c..c7347933cb 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -39,8 +39,6 @@ from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class BaseConversationQuery(BaseModel): keyword: str | None = Field(default=None, description="Search keyword") @@ -70,15 +68,6 @@ class ChatConversationQuery(BaseConversationQuery): ) -console_ns.schema_model( - CompletionConversationQuery.__name__, - CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChatConversationQuery.__name__, - ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - register_schema_models( console_ns, CompletionConversationQuery, @@ -89,6 +78,8 @@ register_schema_models( ConversationWithSummaryPaginationResponse, ConversationDetailResponse, ResultResponse, + CompletionConversationQuery, + ChatConversationQuery, ) @@ -107,7 +98,7 @@ class CompletionConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) query = sa.select(Conversation).where( Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) @@ -221,7 +212,7 @@ class ChatConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) subquery = ( sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 9c8b095b9f..60a2bfc799 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -100,7 +100,7 @@ class ConversationVariablesApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.ADVANCED_CHAT) def get(self, app_model): - args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) stmt = ( select(ConversationVariable) diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index cbcf513162..ee2fc39f86 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -5,14 +5,13 @@ from flask_restx import Resource, fields from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class TraceProviderQuery(BaseModel): tracing_provider: str = Field(..., description="Tracing provider name") @@ -23,13 +22,7 @@ class TraceConfigPayload(BaseModel): tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data") -console_ns.schema_model( - TraceProviderQuery.__name__, - TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, TraceProviderQuery, TraceConfigPayload) @console_ns.route("/apps//trace-config") @@ -50,7 +43,7 @@ class TraceAppConfigApi(Resource): @login_required @account_initialization_required def get(self, app_id): - args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) @@ -121,7 +114,7 @@ class TraceAppConfigApi(Resource): @account_initialization_required def delete(self, app_id): """Delete an existing trace app configuration""" - args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index ffa28b1c95..d23b2837c9 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -5,6 +5,7 @@ from flask import abort, jsonify, request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -15,8 +16,6 @@ from libs.helper import convert_datetime_to_date from libs.login import current_account_with_tenant, login_required from models import AppMode -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class StatisticTimeRangeQuery(BaseModel): start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)") @@ -30,10 +29,7 @@ class StatisticTimeRangeQuery(BaseModel): return value -console_ns.schema_model( - StatisticTimeRangeQuery.__name__, - StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, StatisticTimeRangeQuery) @console_ns.route("/apps//statistics/daily-messages") @@ -54,7 +50,7 @@ class DailyMessageStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -111,7 +107,7 @@ class DailyConversationStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -167,7 +163,7 @@ class DailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -224,7 +220,7 @@ class DailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -284,7 +280,7 @@ class AverageSessionInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("c.created_at") sql_query = f"""SELECT @@ -360,7 +356,7 @@ class UserSatisfactionRateStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("m.created_at") sql_query = f"""SELECT @@ -426,7 +422,7 @@ class AverageResponseTimeStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -482,7 +478,7 @@ class TokensPerSecondStatistic(Resource): @account_initialization_required def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index e18688f069..4f532b437c 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -60,7 +60,7 @@ logger = logging.getLogger(__name__) _file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000 WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50 @@ -912,7 +912,7 @@ class DefaultBlockConfigApi(Resource): """ Get default block config """ - args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) filters = None if args.q: @@ -1005,7 +1005,7 @@ class PublishedAllWorkflowApi(Resource): """ current_user, _ = current_account_with_tenant() - args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page limit = args.limit user_id = args.user_id diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 4b39590235..ddc900eb2d 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -185,7 +185,7 @@ class WorkflowAppLogApi(Resource): """ Get workflow app logs """ - args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # get paginate workflow app logs workflow_app_service = WorkflowAppService() @@ -228,7 +228,7 @@ class WorkflowArchivedLogApi(Resource): """ Get workflow archived logs """ - args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) workflow_app_service = WorkflowAppService() with sessionmaker(db.engine, expire_on_commit=False).begin() as session: diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py index e7c3e982a6..c003be1303 100644 --- a/api/controllers/console/app/workflow_comment.py +++ b/api/controllers/console/app/workflow_comment.py @@ -23,7 +23,6 @@ from services.account_service import TenantService from services.workflow_comment_service import WorkflowCommentService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkflowCommentCreatePayload(BaseModel): @@ -52,13 +51,14 @@ class WorkflowCommentMentionUsersPayload(BaseModel): users: list[AccountWithRole] -for model in ( +register_schema_models( + console_ns, + AccountWithRole, + WorkflowCommentMentionUsersPayload, WorkflowCommentCreatePayload, WorkflowCommentUpdatePayload, WorkflowCommentReplyPayload, -): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload) +) workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields) workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index c688a69074..3c887c33dc 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -8,6 +8,7 @@ from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, @@ -33,7 +34,6 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) _file_access_controller = DatabaseFileAccessController() -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkflowDraftVariableListQuery(BaseModel): @@ -56,21 +56,12 @@ class EnvironmentVariableUpdatePayload(BaseModel): environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow") -console_ns.schema_model( - WorkflowDraftVariableListQuery.__name__, - WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - WorkflowDraftVariableUpdatePayload.__name__, - WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ConversationVariableUpdatePayload.__name__, - ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - EnvironmentVariableUpdatePayload.__name__, - EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +register_schema_models( + console_ns, + WorkflowDraftVariableListQuery, + WorkflowDraftVariableUpdatePayload, + ConversationVariableUpdatePayload, + EnvironmentVariableUpdatePayload, ) @@ -260,7 +251,7 @@ class WorkflowVariableCollectionApi(Resource): """ Get draft workflow """ - args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # fetch draft workflow by app_model workflow_service = WorkflowService() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index e42aae6090..97d2003209 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -154,7 +154,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource): """ Get advanced chat app workflow run list """ - args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) args: WorkflowRunListArgs = {"limit": args_model.limit} if args_model.last_id is not None: args["last_id"] = args_model.last_id @@ -250,7 +250,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource): """ Get advanced chat workflow runs count statistics """ - args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING if not specified @@ -290,7 +290,7 @@ class WorkflowRunListApi(Resource): """ Get workflow run list """ - args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) args: WorkflowRunListArgs = {"limit": args_model.limit} if args_model.last_id is not None: args["last_id"] = args_model.last_id @@ -331,7 +331,7 @@ class WorkflowRunCountApi(Resource): """ Get workflow runs count statistics """ - args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING for workflow if not specified (backward compatibility) diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index e48cf42762..ca899d8784 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -3,6 +3,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -13,8 +14,6 @@ from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class WorkflowStatisticQuery(BaseModel): start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)") @@ -28,10 +27,7 @@ class WorkflowStatisticQuery(BaseModel): return value -console_ns.schema_model( - WorkflowStatisticQuery.__name__, - WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, WorkflowStatisticQuery) @console_ns.route("/apps//workflow/statistics/daily-conversations") @@ -53,7 +49,7 @@ class WorkflowDailyRunsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -93,7 +89,7 @@ class WorkflowDailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -133,7 +129,7 @@ class WorkflowDailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -173,7 +169,7 @@ class WorkflowAverageAppInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index a6715fa200..a80b4f5d0c 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -94,7 +94,7 @@ class WebhookTriggerApi(Resource): @console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__]) def get(self, app_model: App): """Get webhook trigger for a node""" - args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = Parser.model_validate(request.args.to_dict(flat=True)) node_id = args.node_id diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index f7061f820f..0c05cf2fe3 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -63,7 +63,7 @@ class ActivateCheckApi(Resource): console_ns.models[ActivationCheckResponse.__name__], ) def get(self): - args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) workspaceId = args.workspace_id token = args.token diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 905d0daef0..db0d36af6e 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,6 +1,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from libs.login import current_account_with_tenant, login_required from services.auth.api_key_auth_service import ApiKeyAuthService @@ -8,8 +9,6 @@ from .. import console_ns from ..auth.error import ApiKeyAuthFailedError from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ApiKeyAuthBindingPayload(BaseModel): category: str = Field(...) @@ -17,10 +16,7 @@ class ApiKeyAuthBindingPayload(BaseModel): credentials: dict = Field(...) -console_ns.schema_model( - ApiKeyAuthBindingPayload.__name__, - ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, ApiKeyAuthBindingPayload) @console_ns.route("/api-key-auth/data-source") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 1fd781b4fc..f6b8aedf22 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field, field_validator from configs import dify_config from constants.languages import languages +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, @@ -23,8 +24,6 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError from ..error import AccountInFreezeError, EmailSendIpLimitError from ..wraps import email_password_login_enabled, email_register_enabled, setup_required -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class EmailRegisterSendPayload(BaseModel): email: EmailStr = Field(..., description="Email address") @@ -48,8 +47,7 @@ class EmailRegisterResetPayload(BaseModel): return valid_password(value) -for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload) @console_ns.route("/email-register/send-email") diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ed390a5f89..c34dd1ac85 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -28,8 +28,6 @@ from services.entities.auth_entities import ( ) from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ForgotPasswordEmailResponse(BaseModel): result: str = Field(description="Operation result") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 8216b3d0da..19c98f3a1a 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -9,6 +9,7 @@ from werkzeug.exceptions import Unauthorized import services from configs import dify_config from constants.languages import get_valid_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( AuthenticationFailedError, @@ -50,7 +51,6 @@ from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" logger = logging.getLogger(__name__) @@ -71,13 +71,7 @@ class EmailCodeLoginPayload(BaseModel): language: str | None = Field(default=None) -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(LoginPayload) -reg(EmailPayload) -reg(EmailCodeLoginPayload) +register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload) @console_ns.route("/login") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 7caf5b52ed..a43caa8f56 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -4,6 +4,7 @@ from flask_restx import ( # type: ignore from pydantic import BaseModel from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required @@ -12,8 +13,6 @@ from models import Account from models.dataset import Pipeline from services.rag_pipeline.rag_pipeline import RagPipelineService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class Parser(BaseModel): inputs: dict @@ -21,7 +20,7 @@ class Parser(BaseModel): credential_id: str | None = None -console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, Parser) @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 572f9773a1..bd0e875666 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -80,7 +80,7 @@ class RecommendedAppListApi(Resource): @account_initialization_required def get(self): # language args - args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) language = args.language if language and language in languages: language_prefix = language diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1456301a24..025c517d20 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -10,7 +10,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse -from controllers.common.schema import get_or_create_model +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -120,10 +120,6 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy) -# Pydantic models for request validation -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - class WorkflowRunRequest(BaseModel): inputs: dict files: list | None = None @@ -153,19 +149,7 @@ class CompletionRequest(BaseModel): retriever_from: str = "explore_app" -# Register schemas for Swagger documentation -console_ns.schema_model( - WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, WorkflowRunRequest, ChatRequest, TextToSpeechRequest, CompletionRequest) class TrialAppWorkflowRunApi(TrialAppResource): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 7a6356d052..9ffc18e4c2 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -89,7 +89,7 @@ class CodeBasedExtensionAPI(Resource): @login_required @account_initialization_required def get(self): - query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) return CodeBasedExtensionResponse( module=query.module, diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index d69a59ecb7..68520e540b 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -52,8 +52,6 @@ from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AccountInitPayload(BaseModel): interface_language: str @@ -161,27 +159,26 @@ class CheckEmailUniquePayload(BaseModel): email: EmailStr -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(AccountInitPayload) -reg(AccountNamePayload) -reg(AccountAvatarPayload) -reg(AccountAvatarQuery) -reg(AccountInterfaceLanguagePayload) -reg(AccountInterfaceThemePayload) -reg(AccountTimezonePayload) -reg(AccountPasswordPayload) -reg(AccountDeletePayload) -reg(AccountDeletionFeedbackPayload) -reg(EducationActivatePayload) -reg(EducationAutocompleteQuery) -reg(ChangeEmailSendPayload) -reg(ChangeEmailValidityPayload) -reg(ChangeEmailResetPayload) -reg(CheckEmailUniquePayload) -register_schema_models(console_ns, AccountResponse) +register_schema_models( + console_ns, + AccountResponse, + AccountInitPayload, + AccountNamePayload, + AccountAvatarPayload, + AccountAvatarQuery, + AccountInterfaceLanguagePayload, + AccountInterfaceThemePayload, + AccountTimezonePayload, + AccountPasswordPayload, + AccountDeletePayload, + AccountDeletionFeedbackPayload, + EducationActivatePayload, + EducationAutocompleteQuery, + ChangeEmailSendPayload, + ChangeEmailValidityPayload, + ChangeEmailResetPayload, + CheckEmailUniquePayload, +) def _serialize_account(account) -> dict[str, Any]: @@ -326,7 +323,7 @@ class AccountAvatarApi(Resource): @account_initialization_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) avatar = args.avatar if avatar.startswith(("http://", "https://")): diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index d4be07382a..925f3e1197 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -20,8 +20,6 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class EndpointCreatePayload(BaseModel): plugin_unique_identifier: str @@ -80,10 +78,6 @@ class EndpointDisableResponse(BaseModel): success: bool = Field(description="Operation success") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - register_schema_models( console_ns, EndpointCreatePayload, @@ -215,7 +209,7 @@ class EndpointListApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page page_size = args.page_size @@ -248,7 +242,7 @@ class EndpointListForSinglePluginApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) page = args.page page_size = args.page_size diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index e3bf4c95b8..c2533c9872 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -33,8 +33,6 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class MemberInvitePayload(BaseModel): emails: list[str] = Field(default_factory=list) @@ -59,17 +57,17 @@ class OwnerTransferPayload(BaseModel): token: str -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(MemberInvitePayload) -reg(MemberRoleUpdatePayload) -reg(OwnerTransferEmailPayload) -reg(OwnerTransferCheckPayload) -reg(OwnerTransferPayload) register_enum_models(console_ns, TenantAccountRole) -register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) +register_schema_models( + console_ns, + AccountWithRole, + AccountWithRoleList, + MemberInvitePayload, + MemberRoleUpdatePayload, + OwnerTransferEmailPayload, + OwnerTransferCheckPayload, + OwnerTransferPayload, +) @console_ns.route("/workspaces/current/members") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 4b10561fdb..2f75218c0f 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -5,6 +5,7 @@ from flask import request, send_file from flask_restx import Resource from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from graphon.model_runtime.entities.model_entities import ModelType @@ -15,8 +16,6 @@ from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ParserModelList(BaseModel): model_type: ModelType | None = None @@ -75,18 +74,17 @@ class ParserPreferredProviderType(BaseModel): preferred_provider_type: Literal["system", "custom"] -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(ParserModelList) -reg(ParserCredentialId) -reg(ParserCredentialCreate) -reg(ParserCredentialUpdate) -reg(ParserCredentialDelete) -reg(ParserCredentialSwitch) -reg(ParserCredentialValidate) -reg(ParserPreferredProviderType) +register_schema_models( + console_ns, + ParserModelList, + ParserCredentialId, + ParserCredentialCreate, + ParserCredentialUpdate, + ParserCredentialDelete, + ParserCredentialSwitch, + ParserCredentialValidate, + ParserPreferredProviderType, +) @console_ns.route("/workspaces/current/model-providers") diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index b2d07ff8f9..7f7d6379c3 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -17,7 +17,6 @@ from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ParserGetDefault(BaseModel): @@ -107,6 +106,12 @@ class ParserParameter(BaseModel): model: str +class ParserSwitch(BaseModel): + model: str + model_type: ModelType + credential_id: str + + register_schema_models( console_ns, ParserGetDefault, @@ -119,6 +124,7 @@ register_schema_models( ParserDeleteCredential, ParserParameter, Inner, + ParserSwitch, ) register_enum_models(console_ns, ModelType) @@ -133,7 +139,7 @@ class DefaultModelApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( @@ -261,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource): def get(self, provider: str): _, tenant_id = current_account_with_tenant() - args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() current_credential = model_provider_service.get_model_credential( @@ -387,17 +393,6 @@ class ModelProviderModelCredentialApi(Resource): return {"result": "success"}, 204 -class ParserSwitch(BaseModel): - model: str - model_type: ModelType - credential_id: str - - -console_ns.schema_model( - ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - - @console_ns.route("/workspaces/current/model-providers//models/credentials/switch") class ModelProviderModelCredentialSwitchApi(Resource): @console_ns.expect(console_ns.models[ParserSwitch.__name__]) @@ -468,9 +463,7 @@ class ParserValidate(BaseModel): credentials: dict[str, Any] -console_ns.schema_model( - ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, ParserSwitch, ParserValidate) @console_ns.route("/workspaces/current/model-providers//models/credentials/validate") @@ -515,7 +508,7 @@ class ModelProviderModelParameterRuleApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserParameter.model_validate(request.args.to_dict(flat=True)) _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index b3e344ccea..93e7f3acab 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -211,7 +211,7 @@ class PluginListApi(Resource): @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserList.model_validate(request.args.to_dict(flat=True)) try: plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) except PluginDaemonClientSideError as e: @@ -261,7 +261,7 @@ class PluginIconApi(Resource): @console_ns.expect(console_ns.models[ParserIcon.__name__]) @setup_required def get(self): - args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserIcon.model_validate(request.args.to_dict(flat=True)) try: icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename) @@ -279,7 +279,7 @@ class PluginAssetApi(Resource): @login_required @account_initialization_required def get(self): - args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserAsset.model_validate(request.args.to_dict(flat=True)) _, tenant_id = current_account_with_tenant() try: @@ -421,7 +421,7 @@ class PluginFetchMarketplacePkgApi(Resource): @plugin_permission_required(install_required=True) def get(self): _, tenant_id = current_account_with_tenant() - args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder( @@ -446,7 +446,7 @@ class PluginFetchManifestApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder( @@ -466,7 +466,7 @@ class PluginFetchInstallTasksApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserTasks.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)}) @@ -660,7 +660,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource): current_user, tenant_id = current_account_with_tenant() user_id = current_user.id - args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) try: options = PluginParameterService.get_dynamic_select_options( @@ -822,7 +822,7 @@ class PluginReadmeApi(Resource): @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserReadme.model_validate(request.args.to_dict(flat=True)) return jsonable_encoder( {"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)} ) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 565099db61..a15d8b5918 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -16,6 +16,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.admin import admin_required from controllers.console.error import AccountNotLinkTenantError @@ -39,7 +40,6 @@ from services.file_service import FileService from services.workspace_service import WorkspaceService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkspaceListQuery(BaseModel): @@ -91,15 +91,14 @@ class TenantInfoResponse(ResponseModel): return value -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(WorkspaceListQuery) -reg(SwitchWorkspacePayload) -reg(WorkspaceCustomConfigPayload) -reg(WorkspaceInfoPayload) -reg(TenantInfoResponse) +register_schema_models( + console_ns, + WorkspaceListQuery, + SwitchWorkspacePayload, + WorkspaceCustomConfigPayload, + WorkspaceInfoPayload, + TenantInfoResponse, +) provider_fields = { "provider_name": fields.String, diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index a91e745f80..be7886e831 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -8,13 +8,12 @@ from werkzeug.exceptions import NotFound import services from controllers.common.errors import UnsupportedFileTypeError from controllers.common.file_response import enforce_download_for_html +from controllers.common.schema import register_schema_models from controllers.files import files_ns from extensions.ext_database import db from services.account_service import TenantService from services.file_service import FileService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class FileSignatureQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp used in the signature") @@ -26,12 +25,7 @@ class FilePreviewQuery(FileSignatureQuery): as_attachment: bool = Field(default=False, description="Whether to download as attachment") -files_ns.schema_model( - FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -files_ns.schema_model( - FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, FileSignatureQuery, FilePreviewQuery) @files_ns.route("//image-preview") @@ -58,7 +52,7 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) timestamp = args.timestamp nonce = args.nonce sign = args.sign @@ -100,7 +94,7 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) try: generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 2f1e2f28bd..8ae16ce7f4 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -7,12 +7,11 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.common.errors import UnsupportedFileTypeError from controllers.common.file_response import enforce_download_for_html +from controllers.common.schema import register_schema_models from controllers.files import files_ns from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ToolFileQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp") @@ -21,9 +20,7 @@ class ToolFileQuery(BaseModel): as_attachment: bool = Field(default=False, description="Download as attachment") -files_ns.schema_model( - ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, ToolFileQuery) @files_ns.route("/tools/.") diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index ed3278a28b..462e9ef58e 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -20,8 +20,6 @@ from ..console.wraps import setup_required from ..files import files_ns from ..inner_api.plugin.wraps import get_user -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class PluginUploadQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp for signature verification") @@ -31,9 +29,8 @@ class PluginUploadQuery(BaseModel): user_id: str | None = Field(default=None, description="User identifier") -files_ns.schema_model( - PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, PluginUploadQuery) + register_schema_models(files_ns, FileResponse) @@ -69,7 +66,7 @@ class PluginUploadFileApi(Resource): FileTooLargeError: File exceeds size limit UnsupportedFileTypeError: File type not supported """ - args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) file = request.files.get("file") if file is None: diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 3eb773fa7c..9af66f1960 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_valid from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.common.schema import register_schema_models +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError @@ -34,13 +34,7 @@ from services.tag_service import ( UpdateTagPayload, ) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - -service_api_ns.schema_model( - DatasetPermissionEnum.__name__, - TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_enum_models(service_api_ns, DatasetPermissionEnum) class DatasetCreatePayload(BaseModel): diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 0b09facf58..1cf757912f 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -77,9 +77,6 @@ class DocumentTextCreatePayload(BaseModel): return value -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - class DocumentTextUpdate(BaseModel): name: str | None = None text: str | None = None