chore: encoder

This commit is contained in:
Yeuoly 2024-04-02 19:25:52 +08:00
parent 0202469254
commit 01c6a35966
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
3 changed files with 16 additions and 41 deletions

View File

@ -9,6 +9,7 @@ from typing import Any, Union
from flask import current_app
from core.agent.entities import AgentToolEntity
from core.model_runtime.utils.encoders import jsonable_encoder
from core.provider_manager import ProviderManager
from core.tools import *
from core.tools.entities.common_entities import I18nObject
@ -29,7 +30,6 @@ from core.tools.utils.configuration import (
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.encoder import serialize_base_model_dict
from core.utils.module_import_helper import load_single_subclass_from_source
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
@ -545,7 +545,7 @@ class ToolManager:
"content": "\ud83d\ude01"
}
return json.loads(serialize_base_model_dict({
return jsonable_encoder({
'schema_type': provider.schema_type,
'schema': provider.schema,
'tools': provider.tools,
@ -553,7 +553,7 @@ class ToolManager:
'description': provider.description,
'credentials': masked_credentials,
'privacy_policy': provider.privacy_policy
}))
})
@classmethod
def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:

View File

@ -1,21 +0,0 @@
from pydantic import BaseModel
def serialize_base_model_array(l: list[BaseModel]) -> str:
class _BaseModel(BaseModel):
__root__: list[BaseModel]
"""
{"__root__": [BaseModel, BaseModel, ...]}
"""
return _BaseModel(__root__=l).json()
def serialize_base_model_dict(b: dict) -> str:
class _BaseModel(BaseModel):
__root__: dict
"""
{"__root__": {BaseModel}}
"""
return _BaseModel(__root__=b).json()

View File

@ -3,6 +3,7 @@ import logging
from httpx import get
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import (
@ -18,7 +19,6 @@ from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
@ -89,9 +89,9 @@ class ToolManageService:
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name)
return json.loads(serialize_base_model_array([
return jsonable_encoder([
v for _, v in (provider.credentials_schema or {}).items()
]))
])
@staticmethod
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
@ -152,14 +152,12 @@ class ToolManageService:
),
]
return json.loads(serialize_base_model_dict(
{
'schema_type': schema_type,
'parameters_schema': tool_bundles,
'credentials_schema': credentials_schema,
'warning': warnings
}
))
return jsonable_encoder({
'schema_type': schema_type,
'parameters_schema': tool_bundles,
'credentials_schema': credentials_schema,
'warning': warnings
})
except Exception as e:
raise ValueError(f'invalid schema: {str(e)}')
@ -213,7 +211,7 @@ class ToolManageService:
schema=schema,
description=extra_info.get('description', ''),
schema_type_str=schema_type,
tools_str=serialize_base_model_array(tool_bundles),
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str={},
privacy_policy=privacy_policy
)
@ -406,7 +404,7 @@ class ToolManageService:
provider.schema = schema
provider.description = extra_info.get('description', '')
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
provider.tools_str = serialize_base_model_array(tool_bundles)
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
if 'auth_type' not in credentials:
@ -515,9 +513,7 @@ class ToolManageService:
) for tool in tools
]
return json.loads(
serialize_base_model_array(result)
)
return jsonable_encoder(result)
@staticmethod
def delete_api_tool_provider(
@ -586,7 +582,7 @@ class ToolManageService:
schema=schema,
description='',
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
tools_str=serialize_base_model_array(tool_bundles),
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str=json.dumps(credentials),
)