fix type

fix
This commit is contained in:
Asuka Minato 2025-12-08 20:48:57 +09:00
parent 3225ca8337
commit 995bfe16fc
11 changed files with 51 additions and 50 deletions

View File

@ -3,15 +3,16 @@ from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from controllers.console import console_ns
from controllers.common.schema import register_schema_models
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
class CodeBasedExtensionQuery(BaseModel): class CodeBasedExtensionQuery(BaseModel):
module: str module: str

View File

@ -1,8 +1,11 @@
from typing import Literal
from flask import request from flask import request
from flask_restx import Resource, marshal_with from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields from fields.tag_fields import dataset_tag_fields

View File

@ -24,6 +24,7 @@ from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value from libs.helper import alphanumeric, uuid_value
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -74,12 +75,12 @@ class BuiltinToolUpdatePayload(BaseModel):
class ApiToolProviderBasePayload(BaseModel): class ApiToolProviderBasePayload(BaseModel):
credentials: dict[str, Any] credentials: dict[str, Any]
schema_type: str schema_type: str
schema: str schema_: str = Field(alias="schema")
provider: str provider: str
icon: dict[str, Any] icon: dict[str, Any]
privacy_policy: str | None = None privacy_policy: str | None = None
labels: list[str] | None = None labels: list[str] | None = None
custom_disclaimer: str | None = None custom_disclaimer: str = ""
class ApiToolProviderAddPayload(ApiToolProviderBasePayload): class ApiToolProviderAddPayload(ApiToolProviderBasePayload):
@ -110,7 +111,7 @@ class ApiToolProviderDeletePayload(BaseModel):
class ApiToolSchemaPayload(BaseModel): class ApiToolSchemaPayload(BaseModel):
schema: str schema_: str = Field(alias="schema")
class ApiToolTestPayload(BaseModel): class ApiToolTestPayload(BaseModel):
@ -119,7 +120,7 @@ class ApiToolTestPayload(BaseModel):
credentials: dict[str, Any] credentials: dict[str, Any]
parameters: dict[str, Any] parameters: dict[str, Any]
schema_type: str schema_type: str
schema: str schema_: str = Field(alias="schema")
class WorkflowToolBasePayload(BaseModel): class WorkflowToolBasePayload(BaseModel):
@ -127,7 +128,7 @@ class WorkflowToolBasePayload(BaseModel):
label: str label: str
description: str description: str
icon: dict[str, Any] icon: dict[str, Any]
parameters: list[dict[str, Any]] parameters: list[WorkflowToolParameterConfiguration]
privacy_policy: str | None = "" privacy_policy: str | None = ""
labels: list[str] | None = None labels: list[str] | None = None
@ -205,7 +206,7 @@ class MCPProviderBasePayload(BaseModel):
name: str name: str
icon: str icon: str
icon_type: str icon_type: str
icon_background: str | None = "" icon_background: str = ""
server_identifier: str server_identifier: str
configuration: dict[str, Any] | None = Field(default_factory=dict) configuration: dict[str, Any] | None = Field(default_factory=dict)
headers: dict[str, Any] | None = Field(default_factory=dict) headers: dict[str, Any] | None = Field(default_factory=dict)
@ -266,10 +267,10 @@ class ToolProviderListApi(Resource):
user_id = user.id user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = ToolProviderListQuery.model_validate(raw_args) query = ToolProviderListQuery.model_validate(raw_args)
return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
@ -411,7 +412,7 @@ class ToolApiProviderAddApi(Resource):
payload.icon, payload.icon,
payload.credentials, payload.credentials,
payload.schema_type, payload.schema_type,
payload.schema, payload.schema_,
payload.privacy_policy or "", payload.privacy_policy or "",
payload.custom_disclaimer or "", payload.custom_disclaimer or "",
payload.labels or [], payload.labels or [],
@ -428,7 +429,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
user_id = user.id user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = UrlQuery.model_validate(raw_args) query = UrlQuery.model_validate(raw_args)
return ApiToolManageService.get_api_tool_provider_remote_schema( return ApiToolManageService.get_api_tool_provider_remote_schema(
@ -448,7 +449,7 @@ class ToolApiProviderListToolsApi(Resource):
user_id = user.id user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = ProviderQuery.model_validate(raw_args) query = ProviderQuery.model_validate(raw_args)
return jsonable_encoder( return jsonable_encoder(
@ -482,7 +483,7 @@ class ToolApiProviderUpdateApi(Resource):
payload.icon, payload.icon,
payload.credentials, payload.credentials,
payload.schema_type, payload.schema_type,
payload.schema, payload.schema_,
payload.privacy_policy, payload.privacy_policy,
payload.custom_disclaimer, payload.custom_disclaimer,
payload.labels or [], payload.labels or [],
@ -520,7 +521,7 @@ class ToolApiProviderGetApi(Resource):
user_id = user.id user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = ProviderQuery.model_validate(raw_args) query = ProviderQuery.model_validate(raw_args)
return ApiToolManageService.get_api_tool_provider( return ApiToolManageService.get_api_tool_provider(
@ -555,7 +556,7 @@ class ToolApiProviderSchemaApi(Resource):
payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {})
return ApiToolManageService.parser_api_schema( return ApiToolManageService.parser_api_schema(
schema=payload.schema, schema=payload.schema_,
) )
@ -575,7 +576,7 @@ class ToolApiProviderPreviousTestApi(Resource):
payload.credentials, payload.credentials,
payload.parameters, payload.parameters,
payload.schema_type, payload.schema_type,
payload.schema, payload.schema_,
) )
@ -665,7 +666,7 @@ class ToolWorkflowProviderGetApi(Resource):
user_id = user.id user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = WorkflowToolGetQuery.model_validate(raw_args) query = WorkflowToolGetQuery.model_validate(raw_args)
if query.workflow_tool_id: if query.workflow_tool_id:
@ -696,7 +697,7 @@ class ToolWorkflowProviderListToolApi(Resource):
user_id = user.id user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = WorkflowToolListQuery.model_validate(raw_args) query = WorkflowToolListQuery.model_validate(raw_args)
return jsonable_encoder( return jsonable_encoder(
@ -1155,7 +1156,7 @@ class ToolMCPUpdateApi(Resource):
@console_ns.route("/mcp/oauth/callback") @console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource): class ToolMCPCallbackApi(Resource):
def get(self): def get(self):
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = MCPCallbackQuery.model_validate(raw_args) query = MCPCallbackQuery.model_validate(raw_args)
state_key = query.state state_key = query.state
authorization_code = query.code authorization_code = query.code

View File

@ -7,9 +7,7 @@ from werkzeug.exceptions import Unauthorized
from constants import HEADER_NAME_APP_CODE from constants import HEADER_NAME_APP_CODE
from controllers.common import fields from controllers.common import fields
from controllers.web import web_ns from controllers.common.schema import register_schema_models
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService from libs.passport import PassportService
from libs.token import extract_webapp_passport from libs.token import extract_webapp_passport
@ -19,6 +17,10 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
from . import web_ns
from .error import AppUnavailableError
from .wraps import WebApiResource
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,6 +31,9 @@ class AppAccessModeQuery(BaseModel):
app_code: str | None = Field(default=None, alias="appCode", description="Application code") app_code: str | None = Field(default=None, alias="appCode", description="Application code")
register_schema_models(web_ns, AppAccessModeQuery)
@web_ns.route("/parameters") @web_ns.route("/parameters")
class AppParameterApi(WebApiResource): class AppParameterApi(WebApiResource):
"""Resource for app variables.""" """Resource for app variables."""
@ -99,7 +104,7 @@ class AppAccessMode(Resource):
} }
) )
def get(self): def get(self):
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
args = AppAccessModeQuery.model_validate(raw_args) args = AppAccessModeQuery.model_validate(raw_args)
features = FeatureService.get_system_features() features = FeatureService.get_system_features()

View File

@ -18,8 +18,6 @@ from controllers.web.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
UnsupportedAudioTypeError, UnsupportedAudioTypeError,
) )
from ..common import register_schema_models
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
@ -32,6 +30,9 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError, UnsupportedAudioTypeServiceError,
) )
from ..common.schema import register_schema_models
from ..web.wraps import WebApiResource
class TextToAudioPayload(BaseModel): class TextToAudioPayload(BaseModel):
message_id: str | None = None message_id: str | None = None
@ -46,6 +47,7 @@ class TextToAudioPayload(BaseModel):
return value return value
return uuid_value(value) return uuid_value(value)
register_schema_models(web_ns, TextToAudioPayload) register_schema_models(web_ns, TextToAudioPayload)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -82,7 +82,7 @@ class ConversationListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = ConversationListQuery.model_validate(raw_args) query = ConversationListQuery.model_validate(raw_args)
try: try:
@ -171,7 +171,7 @@ class ConversationRenameApi(WebApiResource):
payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
try: try:
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) # type: ignore
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@ -108,7 +108,7 @@ class MessageListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = MessageListQuery.model_validate(raw_args) query = MessageListQuery.model_validate(raw_args)
try: try:
@ -182,7 +182,7 @@ class MessageMoreLikeThisApi(WebApiResource):
message_id = str(message_id) message_id = str(message_id)
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = MessageMoreLikeThisQuery.model_validate(raw_args) query = MessageMoreLikeThisQuery.model_validate(raw_args)
streaming = query.response_mode == "streaming" streaming = query.response_mode == "streaming"

View File

@ -85,7 +85,7 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type] raw_args = request.args.to_dict()
query = SavedMessageListQuery.model_validate(raw_args) query = SavedMessageListQuery.model_validate(raw_args)
return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)

View File

@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils: class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod @classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
""" """

View File

@ -247,14 +247,14 @@ class ApiToolManageService:
credentials: dict, credentials: dict,
schema_type: str, schema_type: str,
schema: str, schema: str,
privacy_policy: str, privacy_policy: str | None,
custom_disclaimer: str, custom_disclaimer: str,
labels: list[str], labels: list[str],
): ):
""" """
update api tool provider update api tool provider
""" """
if schema_type not in [member.value for member in ApiProviderSchemaType]: if schema_type not in list(ApiProviderSchemaType):
raise ValueError(f"invalid schema type {schema}") raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip() provider_name = provider_name.strip()

View File

@ -1,8 +1,6 @@
import json import json
import logging import logging
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from typing import Any
from sqlalchemy import or_, select from sqlalchemy import or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -11,8 +9,8 @@ from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
@ -39,12 +37,10 @@ class WorkflowToolManageService:
label: str, label: str,
icon: dict, icon: dict,
description: str, description: str,
parameters: list[Mapping[str, Any]], parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "", privacy_policy: str = "",
labels: list[str] | None = None, labels: list[str] | None = None,
): ):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique # check if the name is unique
existing_workflow_tool_provider = ( existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
@ -77,7 +73,7 @@ class WorkflowToolManageService:
label=label, label=label,
icon=json.dumps(icon), icon=json.dumps(icon),
description=description, description=description,
parameter_configuration=json.dumps(parameters), parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
privacy_policy=privacy_policy, privacy_policy=privacy_policy,
version=workflow.version, version=workflow.version,
) )
@ -108,7 +104,7 @@ class WorkflowToolManageService:
label: str, label: str,
icon: dict, icon: dict,
description: str, description: str,
parameters: list[Mapping[str, Any]], parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "", privacy_policy: str = "",
labels: list[str] | None = None, labels: list[str] | None = None,
): ):
@ -126,8 +122,6 @@ class WorkflowToolManageService:
:param labels: labels :param labels: labels
:return: the updated tool :return: the updated tool
""" """
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique # check if the name is unique
existing_workflow_tool_provider = ( existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
@ -166,7 +160,7 @@ class WorkflowToolManageService:
workflow_tool_provider.label = label workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps(parameters) workflow_tool_provider.parameter_configuration = json.dumps(obj=[p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now() workflow_tool_provider.updated_at = datetime.now()