diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index c2cb4578ea..e333f5b129 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -3,15 +3,16 @@ from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field from constants import HIDDEN_VALUE -from controllers.console import console_ns -from controllers.common.schema import register_schema_models -from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import current_account_with_tenant, login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService +from ..common.schema import register_schema_models +from . import console_ns +from .wraps import account_initialization_required, setup_required + class CodeBasedExtensionQuery(BaseModel): module: str diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index c545348cde..108fa88d05 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,8 +1,11 @@ +from typing import Literal + from flask import request from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d9c1220b78..8fa52608de 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -24,6 +24,7 @@ from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required @@ -74,12 +75,12 @@ class BuiltinToolUpdatePayload(BaseModel): class ApiToolProviderBasePayload(BaseModel): credentials: dict[str, Any] schema_type: str - schema: str + schema_: str = Field(alias="schema") provider: str icon: dict[str, Any] privacy_policy: str | None = None labels: list[str] | None = None - custom_disclaimer: str | None = None + custom_disclaimer: str = "" class ApiToolProviderAddPayload(ApiToolProviderBasePayload): @@ -110,7 +111,7 @@ class ApiToolProviderDeletePayload(BaseModel): class ApiToolSchemaPayload(BaseModel): - schema: str + schema_: str = Field(alias="schema") class ApiToolTestPayload(BaseModel): @@ -119,7 +120,7 @@ class ApiToolTestPayload(BaseModel): credentials: dict[str, Any] parameters: dict[str, Any] schema_type: str - schema: str + schema_: str = Field(alias="schema") class WorkflowToolBasePayload(BaseModel): @@ -127,7 +128,7 @@ class WorkflowToolBasePayload(BaseModel): label: str description: str icon: dict[str, Any] - parameters: list[dict[str, Any]] + parameters: list[WorkflowToolParameterConfiguration] privacy_policy: str | None = "" labels: list[str] | None = None @@ -205,7 +206,7 @@ class MCPProviderBasePayload(BaseModel): name: str icon: str icon_type: str - icon_background: str | None = "" + icon_background: str = "" server_identifier: str configuration: dict[str, Any] | None = Field(default_factory=dict) headers: dict[str, Any] | None = Field(default_factory=dict) @@ -266,10 +267,10 @@ class ToolProviderListApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) + return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -411,7 +412,7 @@ class ToolApiProviderAddApi(Resource): payload.icon, payload.credentials, payload.schema_type, - payload.schema, + payload.schema_, payload.privacy_policy or "", payload.custom_disclaimer or "", payload.labels or [], @@ -428,7 +429,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( @@ -448,7 +449,7 @@ class ToolApiProviderListToolsApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( @@ -482,7 +483,7 @@ class ToolApiProviderUpdateApi(Resource): payload.icon, payload.credentials, payload.schema_type, - payload.schema, + payload.schema_, payload.privacy_policy, payload.custom_disclaimer, payload.labels or [], @@ -520,7 +521,7 @@ class ToolApiProviderGetApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( @@ -555,7 +556,7 @@ class ToolApiProviderSchemaApi(Resource): payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.parser_api_schema( - schema=payload.schema, + schema=payload.schema_, ) @@ -575,7 +576,7 @@ class ToolApiProviderPreviousTestApi(Resource): payload.credentials, payload.parameters, payload.schema_type, - payload.schema, + payload.schema_, ) @@ -665,7 +666,7 @@ class ToolWorkflowProviderGetApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = WorkflowToolGetQuery.model_validate(raw_args) if query.workflow_tool_id: @@ -696,7 +697,7 @@ class ToolWorkflowProviderListToolApi(Resource): user_id = user.id - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( @@ -1155,7 +1156,7 @@ class ToolMCPUpdateApi(Resource): @console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): def get(self): - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = MCPCallbackQuery.model_validate(raw_args) state_key = query.state authorization_code = query.code diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 9721003eab..b8eefcb1e0 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -7,9 +7,7 @@ from werkzeug.exceptions import Unauthorized from constants import HEADER_NAME_APP_CODE from controllers.common import fields -from controllers.web import web_ns -from controllers.web.error import AppUnavailableError -from controllers.web.wraps import WebApiResource +from controllers.common.schema import register_schema_models from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from libs.passport import PassportService from libs.token import extract_webapp_passport @@ -19,6 +17,10 @@ from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService +from . import web_ns +from .error import AppUnavailableError +from .wraps import WebApiResource + logger = logging.getLogger(__name__) @@ -29,6 +31,9 @@ class AppAccessModeQuery(BaseModel): app_code: str | None = Field(default=None, alias="appCode", description="Application code") +register_schema_models(web_ns, AppAccessModeQuery) + + @web_ns.route("/parameters") class AppParameterApi(WebApiResource): """Resource for app variables.""" @@ -99,7 +104,7 @@ class AppAccessMode(Resource): } ) def get(self): - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() args = AppAccessModeQuery.model_validate(raw_args) features = FeatureService.get_system_features() diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 30bd07847e..af31fbc334 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -18,8 +18,6 @@ from controllers.web.error import ( ProviderQuotaExceededError, UnsupportedAudioTypeError, ) -from ..common import register_schema_models -from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -32,6 +30,9 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) +from ..common.schema import register_schema_models +from ..web.wraps import WebApiResource + class TextToAudioPayload(BaseModel): message_id: str | None = None @@ -46,6 +47,7 @@ class TextToAudioPayload(BaseModel): return value return uuid_value(value) + register_schema_models(web_ns, TextToAudioPayload) logger = logging.getLogger(__name__) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 3281e0b001..21f48ebb36 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -82,7 +82,7 @@ class ConversationListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = ConversationListQuery.model_validate(raw_args) try: @@ -171,7 +171,7 @@ class ConversationRenameApi(WebApiResource): payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) + return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) # type: ignore except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index fdd9352972..1a7327ef26 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -108,7 +108,7 @@ class MessageListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = MessageListQuery.model_validate(raw_args) try: @@ -182,7 +182,7 @@ class MessageMoreLikeThisApi(WebApiResource): message_id = str(message_id) - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = MessageMoreLikeThisQuery.model_validate(raw_args) streaming = query.response_mode == "streaming" diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index f407507883..4d3b190f7c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -85,7 +85,7 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] + raw_args = request.args.to_dict() query = SavedMessageListQuery.model_validate(raw_args) return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 188da0c32d..6d75df3603 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity class WorkflowToolConfigurationUtils: - @classmethod - def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): - for configuration in configurations: - WorkflowToolParameterConfiguration.model_validate(configuration) - @classmethod def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: """ diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index b3b6e36346..82570afca1 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -247,14 +247,14 @@ class ApiToolManageService: credentials: dict, schema_type: str, schema: str, - privacy_policy: str, + privacy_policy: str | None, custom_disclaimer: str, labels: list[str], ): """ update api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: + if schema_type not in list(ApiProviderSchemaType): raise ValueError(f"invalid schema type {schema}") provider_name = provider_name.strip() diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index fe77ff2dc5..4f2ac94e91 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,8 +1,6 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -11,8 +9,8 @@ from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db @@ -39,12 +37,10 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -77,7 +73,7 @@ class WorkflowToolManageService: label=label, icon=json.dumps(icon), description=description, - parameter_configuration=json.dumps(parameters), + parameter_configuration=json.dumps([p.model_dump() for p in parameters]), privacy_policy=privacy_policy, version=workflow.version, ) @@ -108,7 +104,7 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): @@ -126,8 +122,6 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -166,7 +160,7 @@ class WorkflowToolManageService: workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps(parameters) + workflow_tool_provider.parameter_configuration = json.dumps(obj=[p.model_dump() for p in parameters]) workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.version = workflow.version workflow_tool_provider.updated_at = datetime.now()