fix type

fix
This commit is contained in:
Asuka Minato 2025-12-08 20:48:57 +09:00
parent 3225ca8337
commit 995bfe16fc
11 changed files with 51 additions and 50 deletions

View File

@ -3,15 +3,16 @@ from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants import HIDDEN_VALUE
from controllers.console import console_ns
from controllers.common.schema import register_schema_models
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
class CodeBasedExtensionQuery(BaseModel):
module: str

View File

@ -1,8 +1,11 @@
from typing import Literal
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields

View File

@ -24,6 +24,7 @@ from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import current_account_with_tenant, login_required
@ -74,12 +75,12 @@ class BuiltinToolUpdatePayload(BaseModel):
class ApiToolProviderBasePayload(BaseModel):
credentials: dict[str, Any]
schema_type: str
schema: str
schema_: str = Field(alias="schema")
provider: str
icon: dict[str, Any]
privacy_policy: str | None = None
labels: list[str] | None = None
custom_disclaimer: str | None = None
custom_disclaimer: str = ""
class ApiToolProviderAddPayload(ApiToolProviderBasePayload):
@ -110,7 +111,7 @@ class ApiToolProviderDeletePayload(BaseModel):
class ApiToolSchemaPayload(BaseModel):
schema: str
schema_: str = Field(alias="schema")
class ApiToolTestPayload(BaseModel):
@ -119,7 +120,7 @@ class ApiToolTestPayload(BaseModel):
credentials: dict[str, Any]
parameters: dict[str, Any]
schema_type: str
schema: str
schema_: str = Field(alias="schema")
class WorkflowToolBasePayload(BaseModel):
@ -127,7 +128,7 @@ class WorkflowToolBasePayload(BaseModel):
label: str
description: str
icon: dict[str, Any]
parameters: list[dict[str, Any]]
parameters: list[WorkflowToolParameterConfiguration]
privacy_policy: str | None = ""
labels: list[str] | None = None
@ -205,7 +206,7 @@ class MCPProviderBasePayload(BaseModel):
name: str
icon: str
icon_type: str
icon_background: str | None = ""
icon_background: str = ""
server_identifier: str
configuration: dict[str, Any] | None = Field(default_factory=dict)
headers: dict[str, Any] | None = Field(default_factory=dict)
@ -266,10 +267,10 @@ class ToolProviderListApi(Resource):
user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = ToolProviderListQuery.model_validate(raw_args)
return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type)
return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
@ -411,7 +412,7 @@ class ToolApiProviderAddApi(Resource):
payload.icon,
payload.credentials,
payload.schema_type,
payload.schema,
payload.schema_,
payload.privacy_policy or "",
payload.custom_disclaimer or "",
payload.labels or [],
@ -428,7 +429,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = UrlQuery.model_validate(raw_args)
return ApiToolManageService.get_api_tool_provider_remote_schema(
@ -448,7 +449,7 @@ class ToolApiProviderListToolsApi(Resource):
user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = ProviderQuery.model_validate(raw_args)
return jsonable_encoder(
@ -482,7 +483,7 @@ class ToolApiProviderUpdateApi(Resource):
payload.icon,
payload.credentials,
payload.schema_type,
payload.schema,
payload.schema_,
payload.privacy_policy,
payload.custom_disclaimer,
payload.labels or [],
@ -520,7 +521,7 @@ class ToolApiProviderGetApi(Resource):
user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = ProviderQuery.model_validate(raw_args)
return ApiToolManageService.get_api_tool_provider(
@ -555,7 +556,7 @@ class ToolApiProviderSchemaApi(Resource):
payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {})
return ApiToolManageService.parser_api_schema(
schema=payload.schema,
schema=payload.schema_,
)
@ -575,7 +576,7 @@ class ToolApiProviderPreviousTestApi(Resource):
payload.credentials,
payload.parameters,
payload.schema_type,
payload.schema,
payload.schema_,
)
@ -665,7 +666,7 @@ class ToolWorkflowProviderGetApi(Resource):
user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = WorkflowToolGetQuery.model_validate(raw_args)
if query.workflow_tool_id:
@ -696,7 +697,7 @@ class ToolWorkflowProviderListToolApi(Resource):
user_id = user.id
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = WorkflowToolListQuery.model_validate(raw_args)
return jsonable_encoder(
@ -1155,7 +1156,7 @@ class ToolMCPUpdateApi(Resource):
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
def get(self):
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = MCPCallbackQuery.model_validate(raw_args)
state_key = query.state
authorization_code = query.code

View File

@ -7,9 +7,7 @@ from werkzeug.exceptions import Unauthorized
from constants import HEADER_NAME_APP_CODE
from controllers.common import fields
from controllers.web import web_ns
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from controllers.common.schema import register_schema_models
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService
from libs.token import extract_webapp_passport
@ -19,6 +17,10 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
from . import web_ns
from .error import AppUnavailableError
from .wraps import WebApiResource
logger = logging.getLogger(__name__)
@ -29,6 +31,9 @@ class AppAccessModeQuery(BaseModel):
app_code: str | None = Field(default=None, alias="appCode", description="Application code")
register_schema_models(web_ns, AppAccessModeQuery)
@web_ns.route("/parameters")
class AppParameterApi(WebApiResource):
"""Resource for app variables."""
@ -99,7 +104,7 @@ class AppAccessMode(Resource):
}
)
def get(self):
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
args = AppAccessModeQuery.model_validate(raw_args)
features = FeatureService.get_system_features()

View File

@ -18,8 +18,6 @@ from controllers.web.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from ..common import register_schema_models
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
@ -32,6 +30,9 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError,
)
from ..common.schema import register_schema_models
from ..web.wraps import WebApiResource
class TextToAudioPayload(BaseModel):
message_id: str | None = None
@ -46,6 +47,7 @@ class TextToAudioPayload(BaseModel):
return value
return uuid_value(value)
register_schema_models(web_ns, TextToAudioPayload)
logger = logging.getLogger(__name__)

View File

@ -82,7 +82,7 @@ class ConversationListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = ConversationListQuery.model_validate(raw_args)
try:
@ -171,7 +171,7 @@ class ConversationRenameApi(WebApiResource):
payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
try:
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) # type: ignore
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -108,7 +108,7 @@ class MessageListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = MessageListQuery.model_validate(raw_args)
try:
@ -182,7 +182,7 @@ class MessageMoreLikeThisApi(WebApiResource):
message_id = str(message_id)
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = MessageMoreLikeThisQuery.model_validate(raw_args)
streaming = query.response_mode == "streaming"

View File

@ -85,7 +85,7 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
raw_args = request.args.to_dict(flat=True) # type: ignore[arg-type]
raw_args = request.args.to_dict()
query = SavedMessageListQuery.model_validate(raw_args)
return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)

View File

@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""

View File

@ -247,14 +247,14 @@ class ApiToolManageService:
credentials: dict,
schema_type: str,
schema: str,
privacy_policy: str,
privacy_policy: str | None,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
if schema_type not in list(ApiProviderSchemaType):
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()

View File

@ -1,8 +1,6 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
@ -11,8 +9,8 @@ from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
@ -39,12 +37,10 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
labels: list[str] | None = None,
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
@ -77,7 +73,7 @@ class WorkflowToolManageService:
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
privacy_policy=privacy_policy,
version=workflow.version,
)
@ -108,7 +104,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
labels: list[str] | None = None,
):
@ -126,8 +122,6 @@ class WorkflowToolManageService:
:param labels: labels
:return: the updated tool
"""
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
@ -166,7 +160,7 @@ class WorkflowToolManageService:
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
workflow_tool_provider.parameter_configuration = json.dumps(obj=[p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()