diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f3fa18393e..61e29672f9 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import Any, Union from flask import current_app from core.agent.entities import AgentToolEntity +from core.model_runtime.utils.encoders import jsonable_encoder from core.provider_manager import ProviderManager from core.tools import * from core.tools.entities.common_entities import I18nObject @@ -29,7 +30,6 @@ from core.tools.utils.configuration import ( ToolConfigurationManager, ToolParameterConfigurationManager, ) -from core.tools.utils.encoder import serialize_base_model_dict from core.utils.module_import_helper import load_single_subclass_from_source from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db @@ -545,7 +545,7 @@ class ToolManager: "content": "\ud83d\ude01" } - return json.loads(serialize_base_model_dict({ + return jsonable_encoder({ 'schema_type': provider.schema_type, 'schema': provider.schema, 'tools': provider.tools, @@ -553,7 +553,7 @@ class ToolManager: 'description': provider.description, 'credentials': masked_credentials, 'privacy_policy': provider.privacy_policy - })) + }) @classmethod def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: diff --git a/api/core/tools/utils/encoder.py b/api/core/tools/utils/encoder.py deleted file mode 100644 index 6d2ea5d7c6..0000000000 --- a/api/core/tools/utils/encoder.py +++ /dev/null @@ -1,21 +0,0 @@ - -from pydantic import BaseModel - - -def serialize_base_model_array(l: list[BaseModel]) -> str: - class _BaseModel(BaseModel): - __root__: list[BaseModel] - - """ - {"__root__": [BaseModel, BaseModel, ...]} - """ - return _BaseModel(__root__=l).json() - -def serialize_base_model_dict(b: dict) -> str: - class _BaseModel(BaseModel): - __root__: dict - - """ - {"__root__": {BaseModel}} - """ - return _BaseModel(__root__=b).json() diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index 29245f1f3b..f342c019f6 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -3,6 +3,7 @@ import logging from httpx import get +from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiBasedToolBundle from core.tools.entities.tool_entities import ( @@ -18,7 +19,6 @@ from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolConfigurationManager -from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -89,9 +89,9 @@ class ToolManageService: :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name) - return json.loads(serialize_base_model_array([ + return jsonable_encoder([ v for _, v in (provider.credentials_schema or {}).items() - ])) + ]) @staticmethod def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]: @@ -152,14 +152,12 @@ class ToolManageService: ), ] - return json.loads(serialize_base_model_dict( - { - 'schema_type': schema_type, - 'parameters_schema': tool_bundles, - 'credentials_schema': credentials_schema, - 'warning': warnings - } - )) + return jsonable_encoder({ + 'schema_type': schema_type, + 'parameters_schema': tool_bundles, + 'credentials_schema': credentials_schema, + 'warning': warnings + }) except Exception as e: raise ValueError(f'invalid schema: {str(e)}') @@ -213,7 +211,7 @@ class ToolManageService: schema=schema, description=extra_info.get('description', ''), schema_type_str=schema_type, - tools_str=serialize_base_model_array(tool_bundles), + tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str={}, privacy_policy=privacy_policy ) @@ -406,7 +404,7 @@ class ToolManageService: provider.schema = schema provider.description = extra_info.get('description', '') provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value - provider.tools_str = serialize_base_model_array(tool_bundles) + provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy if 'auth_type' not in credentials: @@ -515,9 +513,7 @@ class ToolManageService: ) for tool in tools ] - return json.loads( - serialize_base_model_array(result) - ) + return jsonable_encoder(result) @staticmethod def delete_api_tool_provider( @@ -586,7 +582,7 @@ class ToolManageService: schema=schema, description='', schema_type_str=ApiProviderSchemaType.OPENAPI.value, - tools_str=serialize_base_model_array(tool_bundles), + tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), )