refactor tools

This commit is contained in:
Yeuoly 2024-08-30 14:23:14 +08:00
parent 50a5cfe56a
commit 1fa3b9cfd8
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
20 changed files with 239 additions and 435 deletions

View File

@ -10,6 +10,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestInvokeApp, RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM, RequestInvokeLLM,
RequestInvokeModeration, RequestInvokeModeration,
RequestInvokeNode, RequestInvokeNode,
@ -132,6 +133,14 @@ class PluginInvokeAppApi(Resource):
PluginAppBackwardsInvocation.convert_to_event_stream(response) PluginAppBackwardsInvocation.convert_to_event_stream(response)
) )
class PluginInvokeEncryptApi(Resource):
@setup_required
@plugin_inner_api_only
@get_tenant
@plugin_data(payload_type=RequestInvokeEncrypt)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeEncrypt):
""""""
api.add_resource(PluginInvokeLLMApi, '/invoke/llm') api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding') api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
api.add_resource(PluginInvokeRerankApi, '/invoke/rerank') api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')

View File

@ -46,6 +46,8 @@ def enterprise_inner_api_user_auth(view):
user_id = user_id.split(" ")[1] user_id = user_id.split(" ")[1]
inner_api_key = request.headers.get("X-Inner-Api-Key") inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key:
raise ValueError("inner api key not found")
data_to_sign = f"DIFY {user_id}" data_to_sign = f"DIFY {user_id}"

View File

@ -60,7 +60,7 @@ class QueueIterationStartEvent(AppQueueEvent):
node_data: BaseNodeData node_data: BaseNodeData
node_run_index: int node_run_index: int
inputs: dict = None inputs: Optional[dict] = None
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
metadata: Optional[dict] = None metadata: Optional[dict] = None

View File

@ -0,0 +1,30 @@
from enum import Enum
class CommonParameterType(Enum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
BOOLEAN = "boolean"
APP_SELECTOR = "app-selector"
MODEL_CONFIG = "model-config"
class AppSelectorScope(Enum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
class ModelConfigScope(Enum):
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"

View File

@ -1,8 +1,10 @@
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional, Union
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType from models.provider import ProviderQuotaType
@ -100,3 +102,52 @@ class ModelSettings(BaseModel):
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class BasicProviderConfig(BaseModel):
"""
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
@staticmethod
def default(value: str) -> str:
return ""
type: Type = Field(..., description="The type of the credentials")
name: str = Field(..., description="The name of the credentials")
class ProviderConfig(BasicProviderConfig):
"""
Model class for common provider settings like credentials
"""
class Option(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
scope: AppSelectorScope | ModelConfigScope | None
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None

View File

@ -1,4 +1,9 @@
tool_file_manager = { from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
tool_file_manager: dict[str, Any] = {
'manager': None 'manager': None
} }

View File

@ -1,7 +1,9 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from core.entities.provider_entities import BasicProviderConfig
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
@ -30,11 +32,10 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
""" """
Request to invoke LLM Request to invoke LLM
""" """
model_type: ModelType = ModelType.LLM model_type: ModelType = ModelType.LLM
mode: str mode: str
model_parameters: dict[str, Any] = Field(default_factory=dict) model_parameters: dict[str, Any] = Field(default_factory=dict)
prompt_messages: list[PromptMessage] prompt_messages: list[PromptMessage] = Field(default_factory=list)
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list) tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
stop: Optional[list[str]] = Field(default_factory=list) stop: Optional[list[str]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
@ -105,4 +106,11 @@ class RequestInvokeApp(BaseModel):
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
user: Optional[str] = None user: Optional[str] = None
files: list[dict] = Field(default_factory=list) files: list[dict] = Field(default_factory=list)
class RequestInvokeEncrypt(BaseModel):
"""
Request to encryption
"""
opt: Literal["encrypt", "decrypt"]
data: dict = Field(default_factory=dict)
config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
from core.tools.tool.tool import ToolParameter from core.tools.tool.tool import ToolParameter
@ -62,4 +62,4 @@ class UserToolProvider(BaseModel):
} }
class UserToolProviderCredentials(BaseModel): class UserToolProviderCredentials(BaseModel):
credentials: dict[str, ToolProviderCredentials] credentials: dict[str, ProviderConfig]

View File

@ -3,6 +3,7 @@ from typing import Any, Optional, Union, cast
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -137,12 +138,12 @@ class ToolParameterOption(BaseModel):
class ToolParameter(BaseModel): class ToolParameter(BaseModel):
class ToolParameterType(str, Enum): class ToolParameterType(str, Enum):
STRING = "string" STRING = CommonParameterType.STRING.value
NUMBER = "number" NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = "boolean" BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = "select" SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = "secret-input" SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = "file" FILE = CommonParameterType.FILE.value
class ToolParameterForm(Enum): class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool SCHEMA = "schema" # should be set while adding tool
@ -151,16 +152,17 @@ class ToolParameter(BaseModel):
name: str = Field(..., description="The name of the parameter") name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user") label: I18nObject = Field(..., description="The label presented to the user")
human_description: Optional[I18nObject] = Field(None, description="The description presented to the user") human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user") placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
type: ToolParameterType = Field(..., description="The type of the parameter") type: ToolParameterType = Field(..., description="The type of the parameter")
scope: AppSelectorScope | ModelConfigScope | None = None
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None llm_description: Optional[str] = None
required: Optional[bool] = False required: Optional[bool] = False
default: Optional[Union[float, int, str]] = None default: Optional[Union[float, int, str]] = None
min: Optional[Union[float, int]] = None min: Optional[Union[float, int]] = None
max: Optional[Union[float, int]] = None max: Optional[Union[float, int]] = None
options: Optional[list[ToolParameterOption]] = None options: list[ToolParameterOption] = Field(default_factory=list)
@classmethod @classmethod
def get_simple_instance(cls, def get_simple_instance(cls,
@ -211,57 +213,6 @@ class ToolIdentity(BaseModel):
provider: str = Field(..., description="The provider of the tool") provider: str = Field(..., description="The provider of the tool")
icon: Optional[str] = None icon: Optional[str] = None
class ToolCredentialsOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
class ToolProviderCredentials(BaseModel):
class CredentialsType(Enum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
BOOLEAN = "boolean"
@classmethod
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
@staticmethod
def default(value: str) -> str:
return ""
name: str = Field(..., description="The name of the credentials")
type: CredentialsType = Field(..., description="The type of the credentials")
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[ToolCredentialsOption]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None
def to_dict(self) -> dict:
return {
'name': self.name,
'type': self.type.value,
'required': self.required,
'default': self.default,
'options': self.options,
'help': self.help.to_dict() if self.help else None,
'label': self.label.to_dict() if self.label else None,
'url': self.url,
'placeholder': self.placeholder.to_dict() if self.placeholder else None,
}
class ToolRuntimeVariableType(Enum): class ToolRuntimeVariableType(Enum):
TEXT = "text" TEXT = "text"
IMAGE = "image" IMAGE = "image"

View File

@ -3,8 +3,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ProviderConfig,
ToolCredentialsOption, ToolCredentialsOption,
ToolProviderCredentials,
ToolProviderType, ToolProviderType,
) )
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
@ -20,10 +20,10 @@ class ApiToolProviderController(ToolProviderController):
@staticmethod @staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
credentials_schema = { credentials_schema = {
'auth_type': ToolProviderCredentials( 'auth_type': ProviderConfig(
name='auth_type', name='auth_type',
required=True, required=True,
type=ToolProviderCredentials.CredentialsType.SELECT, type=ProviderConfig.Type.SELECT,
options=[ options=[
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='')), ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
@ -38,30 +38,30 @@ class ApiToolProviderController(ToolProviderController):
if auth_type == ApiProviderAuthType.API_KEY: if auth_type == ApiProviderAuthType.API_KEY:
credentials_schema = { credentials_schema = {
**credentials_schema, **credentials_schema,
'api_key_header': ToolProviderCredentials( 'api_key_header': ProviderConfig(
name='api_key_header', name='api_key_header',
required=False, required=False,
default='api_key', default='api_key',
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, type=ProviderConfig.Type.TEXT_INPUT,
help=I18nObject( help=I18nObject(
en_US='The header name of the api key', en_US='The header name of the api key',
zh_Hans='携带 api key 的 header 名称' zh_Hans='携带 api key 的 header 名称'
) )
), ),
'api_key_value': ToolProviderCredentials( 'api_key_value': ProviderConfig(
name='api_key_value', name='api_key_value',
required=True, required=True,
type=ToolProviderCredentials.CredentialsType.SECRET_INPUT, type=ProviderConfig.Type.SECRET_INPUT,
help=I18nObject( help=I18nObject(
en_US='The api key', en_US='The api key',
zh_Hans='api key的值' zh_Hans='api key的值'
) )
), ),
'api_key_header_prefix': ToolProviderCredentials( 'api_key_header_prefix': ProviderConfig(
name='api_key_header_prefix', name='api_key_header_prefix',
required=False, required=False,
default='basic', default='basic',
type=ToolProviderCredentials.CredentialsType.SELECT, type=ProviderConfig.Type.SELECT,
help=I18nObject( help=I18nObject(
en_US='The prefix of the api key header', en_US='The prefix of the api key header',
zh_Hans='api key header 的前缀' zh_Hans='api key header 的前缀'

View File

@ -1,115 +0,0 @@
import logging
from typing import Any
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import PublishedAppTool
logger = logging.getLogger(__name__)
class AppToolProviderEntity(ToolProviderController):
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.APP
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
pass
def get_tools(self, user_id: str) -> list[Tool]:
db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter(
PublishedAppTool.user_id == user_id,
).all()
if not db_tools or len(db_tools) == 0:
return []
tools: list[Tool] = []
for db_tool in db_tools:
tool = {
'identity': {
'author': db_tool.author,
'name': db_tool.tool_name,
'label': {
'en_US': db_tool.tool_name,
'zh_Hans': db_tool.tool_name
},
'icon': ''
},
'description': {
'human': {
'en_US': db_tool.description_i18n.en_US,
'zh_Hans': db_tool.description_i18n.zh_Hans
},
'llm': db_tool.llm_description
},
'parameters': []
}
# get app from db
app: App = db_tool.app
if not app:
logger.error(f"app {db_tool.app_id} not found")
continue
app_model_config: AppModelConfig = app.app_model_config
user_input_form_list = app_model_config.user_input_form_list
for input_form in user_input_form_list:
# get type
form_type = input_form.keys()[0]
default = input_form[form_type]['default']
required = input_form[form_type]['required']
label = input_form[form_type]['label']
variable_name = input_form[form_type]['variable_name']
options = input_form[form_type].get('options', [])
if form_type == 'paragraph' or form_type == 'text-input':
tool['parameters'].append(ToolParameter(
name=variable_name,
label=I18nObject(
en_US=label,
zh_Hans=label
),
human_description=I18nObject(
en_US=label,
zh_Hans=label
),
llm_description=label,
form=ToolParameter.ToolParameterForm.FORM,
type=ToolParameter.ToolParameterType.STRING,
required=required,
default=default
))
elif form_type == 'select':
tool['parameters'].append(ToolParameter(
name=variable_name,
label=I18nObject(
en_US=label,
zh_Hans=label
),
human_description=I18nObject(
en_US=label,
zh_Hans=label
),
llm_description=label,
form=ToolParameter.ToolParameterForm.FORM,
type=ToolParameter.ToolParameterType.SELECT,
required=required,
default=default,
options=[ToolParameterOption(
value=option,
label=I18nObject(
en_US=option,
zh_Hans=option
)
) for option in options]
))
tools.append(Tool(**tool))
return tools

View File

@ -2,22 +2,23 @@ from abc import abstractmethod
from os import listdir, path from os import listdir, path
from typing import Any from typing import Any
from pydantic import Field
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import ( from core.tools.errors import (
ToolNotFoundError,
ToolParameterValidationError,
ToolProviderNotFoundError, ToolProviderNotFoundError,
) )
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.tools.utils.yaml_utils import load_yaml_file from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController): class BuiltinToolProviderController(ToolProviderController):
tools: list[BuiltinTool] = Field(default_factory=list)
def __init__(self, **data: Any) -> None: def __init__(self, **data: Any) -> None:
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
super().__init__(**data) super().__init__(**data)
@ -41,7 +42,7 @@ class BuiltinToolProviderController(ToolProviderController):
'credentials_schema': provider_yaml.get('credentials_for_provider', None), 'credentials_schema': provider_yaml.get('credentials_for_provider', None),
}) })
def _get_builtin_tools(self) -> list[Tool]: def _get_builtin_tools(self) -> list[BuiltinTool]:
""" """
returns a list of tools that the provider can provide returns a list of tools that the provider can provide
@ -72,7 +73,7 @@ class BuiltinToolProviderController(ToolProviderController):
self.tools = tools self.tools = tools
return tools return tools
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: def get_credentials_schema(self) -> dict[str, ProviderConfig]:
""" """
returns the credentials schema of the provider returns the credentials schema of the provider
@ -83,7 +84,7 @@ class BuiltinToolProviderController(ToolProviderController):
return self.credentials_schema.copy() return self.credentials_schema.copy()
def get_tools(self) -> list[Tool]: def get_tools(self) -> list[BuiltinTool]:
""" """
returns a list of tools that the provider can provide returns a list of tools that the provider can provide
@ -91,24 +92,12 @@ class BuiltinToolProviderController(ToolProviderController):
""" """
return self._get_builtin_tools() return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> Tool: def get_tool(self, tool_name: str) -> BuiltinTool | None:
""" """
returns the tool that the provider can provide returns the tool that the provider can provide
""" """
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
returns the parameters of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
if tool is None:
raise ToolNotFoundError(f'tool {tool_name} not found')
return tool.parameters
@property @property
def need_credentials(self) -> bool: def need_credentials(self) -> bool:
""" """
@ -143,67 +132,6 @@ class BuiltinToolProviderController(ToolProviderController):
returns the labels of the provider returns the labels of the provider
""" """
return self.identity.tags or [] return self.identity.tags or []
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
"""
validate the parameters of the tool and set the default value if needed
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
"""
tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], int | float):
raise ToolParameterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required:
raise ToolParameterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
parameter_schema.type)
tool_parameters[parameter] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None: def validate_credentials(self, credentials: dict[str, Any]) -> None:
""" """

View File

@ -1,25 +1,23 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderCredentials,
ToolProviderIdentity, ToolProviderIdentity,
ToolProviderType, ToolProviderType,
) )
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
class ToolProviderController(BaseModel, ABC): class ToolProviderController(BaseModel, ABC):
identity: Optional[ToolProviderIdentity] = None identity: ToolProviderIdentity
tools: Optional[list[Tool]] = None tools: list[Tool] = Field(default_factory=list)
credentials_schema: Optional[dict[str, ToolProviderCredentials]] = None credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: def get_credentials_schema(self) -> dict[str, ProviderConfig]:
""" """
returns the credentials schema of the provider returns the credentials schema of the provider
@ -27,15 +25,6 @@ class ToolProviderController(BaseModel, ABC):
""" """
return self.credentials_schema.copy() return self.credentials_schema.copy()
@abstractmethod
def get_tools(self) -> list[Tool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
pass
@abstractmethod @abstractmethod
def get_tool(self, tool_name: str) -> Tool: def get_tool(self, tool_name: str) -> Tool:
""" """
@ -45,18 +34,6 @@ class ToolProviderController(BaseModel, ABC):
""" """
pass pass
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
returns the parameters of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
if tool is None:
raise ToolNotFoundError(f'tool {tool_name} not found')
return tool.parameters
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
""" """
@ -66,66 +43,6 @@ class ToolProviderController(BaseModel, ABC):
""" """
return ToolProviderType.BUILT_IN return ToolProviderType.BUILT_IN
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
"""
validate the parameters of the tool and set the default value if needed
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
"""
tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], int | float):
raise ToolParameterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required:
raise ToolParameterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
parameter_schema.type)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None: def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
""" """
validate the format of the credentials of the provider and set the default value if needed validate the format of the credentials of the provider and set the default value if needed
@ -136,7 +53,7 @@ class ToolProviderController(BaseModel, ABC):
if credentials_schema is None: if credentials_schema is None:
return return
credentials_need_to_validate: dict[str, ToolProviderCredentials] = {} credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema: for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name] credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
@ -146,12 +63,12 @@ class ToolProviderController(BaseModel, ABC):
# check type # check type
credential_schema = credentials_need_to_validate[credential_name] credential_schema = credentials_need_to_validate[credential_name]
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ if credential_schema == ProviderConfig.Type.SECRET_INPUT or \
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT: credential_schema == ProviderConfig.Type.TEXT_INPUT:
if not isinstance(credentials[credential_name], str): if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str): if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
@ -173,9 +90,9 @@ class ToolProviderController(BaseModel, ABC):
if credential_schema.default is not None: if credential_schema.default is not None:
default_value = credential_schema.default default_value = credential_schema.default
# parse default value into the correct type # parse default value into the correct type
if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ if credential_schema.type == ProviderConfig.Type.SECRET_INPUT or \
credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \ credential_schema.type == ProviderConfig.Type.TEXT_INPUT or \
credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: credential_schema.type == ProviderConfig.Type.SELECT:
default_value = str(default_value) default_value = str(default_value)
credentials[credential_name] = default_value credentials[credential_name] = default_value

View File

@ -1,5 +1,8 @@
from collections.abc import Mapping
from typing import Optional from typing import Optional
from pydantic import Field
from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -28,6 +31,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
class WorkflowToolProviderController(ToolProviderController): class WorkflowToolProviderController(ToolProviderController):
provider_id: str provider_id: str
tools: list[WorkflowTool] = Field(default_factory=list)
@classmethod @classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController': def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
@ -71,16 +75,17 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app :param app: the app
:return: the tool :return: the tool
""" """
workflow: Workflow = db.session.query(Workflow).filter( workflow: Workflow | None = db.session.query(Workflow).filter(
Workflow.app_id == db_provider.app_id, Workflow.app_id == db_provider.app_id,
Workflow.version == db_provider.version Workflow.version == db_provider.version
).first() ).first()
if not workflow: if not workflow:
raise ValueError('workflow not found') raise ValueError('workflow not found')
# fetch start node # fetch start node
graph: dict = workflow.graph_dict graph: Mapping = workflow.graph_dict
features_dict: dict = workflow.features_dict features_dict: Mapping = workflow.features_dict
features = WorkflowAppConfigManager.convert_features( features = WorkflowAppConfigManager.convert_features(
config_dict=features_dict, config_dict=features_dict,
app_mode=AppMode.WORKFLOW app_mode=AppMode.WORKFLOW
@ -89,7 +94,7 @@ class WorkflowToolProviderController(ToolProviderController):
parameters = db_provider.parameter_configurations parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity: def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None) return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user user = db_provider.user
@ -99,7 +104,7 @@ class WorkflowToolProviderController(ToolProviderController):
variable = fetch_workflow_variable(parameter.name) variable = fetch_workflow_variable(parameter.name)
if variable: if variable:
parameter_type = None parameter_type = None
options = None options = []
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f'unsupported variable type {variable.type}') raise ValueError(f'unsupported variable type {variable.type}')
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
@ -185,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
label=db_provider.label label=db_provider.label
) )
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
""" """
fetch tools from database fetch tools from database
@ -196,7 +201,7 @@ class WorkflowToolProviderController(ToolProviderController):
if self.tools is not None: if self.tools is not None:
return self.tools return self.tools
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id, WorkflowToolProvider.app_id == self.provider_id,
).first() ).first()

View File

@ -55,7 +55,7 @@ class Tool(BaseModel, ABC):
invoke_from: Optional[InvokeFrom] = None invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: Optional[dict[str, Any]] = None credentials: Optional[dict[str, Any]] = None
runtime_parameters: Optional[dict[str, Any]] = None runtime_parameters: dict[str, Any] = Field(default_factory=dict)
runtime: Optional[Runtime] = None runtime: Optional[Runtime] = None
variables: Optional[ToolRuntimeVariablePool] = None variables: Optional[ToolRuntimeVariablePool] = None

View File

@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator from collections.abc import Generator
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import Any, Union from typing import Any, Union, cast
from configs import dify_config from configs import dify_config
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
@ -22,6 +22,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
from core.tools.tool.api_tool import ApiTool from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.tool_parameter_converter import ToolParameterConverter
@ -57,7 +58,7 @@ class ToolManager:
return cls._builtin_providers[provider] return cls._builtin_providers[provider]
@classmethod @classmethod
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
""" """
get the builtin tool get the builtin tool
@ -78,7 +79,7 @@ class ToolManager:
tenant_id: str, tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool, WorkflowTool]:
""" """
get the tool runtime get the tool runtime
@ -90,19 +91,21 @@ class ToolManager:
""" """
if provider_type == ToolProviderType.BUILT_IN: if provider_type == ToolProviderType.BUILT_IN:
builtin_tool = cls.get_builtin_tool(provider_id, tool_name) builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
if not builtin_tool:
raise ValueError(f"tool {tool_name} not found")
# check if the builtin tool need credentials # check if the builtin tool need credentials
provider_controller = cls.get_builtin_provider(provider_id) provider_controller = cls.get_builtin_provider(provider_id)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(runtime={ return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
'tenant_id': tenant_id, 'tenant_id': tenant_id,
'credentials': {}, 'credentials': {},
'invoke_from': invoke_from, 'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from, 'tool_invoke_from': tool_invoke_from,
}) }))
# get credentials # get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( builtin_provider: BuiltinToolProvider | None = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_id, BuiltinToolProvider.provider == provider_id,
).first() ).first()
@ -117,13 +120,13 @@ class ToolManager:
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return builtin_tool.fork_tool_runtime(runtime={ return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
'tenant_id': tenant_id, 'tenant_id': tenant_id,
'credentials': decrypted_credentials, 'credentials': decrypted_credentials,
'runtime_parameters': {}, 'runtime_parameters': {},
'invoke_from': invoke_from, 'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from, 'tool_invoke_from': tool_invoke_from,
}) }))
elif provider_type == ToolProviderType.API: elif provider_type == ToolProviderType.API:
if tenant_id is None: if tenant_id is None:
@ -135,12 +138,12 @@ class ToolManager:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
'tenant_id': tenant_id, 'tenant_id': tenant_id,
'credentials': decrypted_credentials, 'credentials': decrypted_credentials,
'invoke_from': invoke_from, 'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from, 'tool_invoke_from': tool_invoke_from,
}) }))
elif provider_type == ToolProviderType.WORKFLOW: elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = db.session.query(WorkflowToolProvider).filter( workflow_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
@ -154,12 +157,12 @@ class ToolManager:
db_provider=workflow_provider db_provider=workflow_provider
) )
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ return cast(WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
'tenant_id': tenant_id, 'tenant_id': tenant_id,
'credentials': {}, 'credentials': {},
'invoke_from': invoke_from, 'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from, 'tool_invoke_from': tool_invoke_from,
}) }))
elif provider_type == ToolProviderType.APP: elif provider_type == ToolProviderType.APP:
raise NotImplementedError('app provider not implemented') raise NotImplementedError('app provider not implemented')
else: else:
@ -220,7 +223,10 @@ class ToolManager:
identity_id=f'AGENT.{app_id}' identity_id=f'AGENT.{app_id}'
) )
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if not tool_entity.runtime:
raise Exception("tool missing runtime")
tool_entity.runtime.runtime_parameters.update(runtime_parameters) tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity return tool_entity
@ -258,6 +264,9 @@ class ToolManager:
if runtime_parameters: if runtime_parameters:
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if not tool_entity.runtime:
raise Exception("tool missing runtime")
tool_entity.runtime.runtime_parameters.update(runtime_parameters) tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity return tool_entity
@ -304,20 +313,20 @@ class ToolManager:
""" """
list all the builtin providers list all the builtin providers
""" """
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'): if provider_path.startswith('__'):
continue continue
if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)): if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider_path)):
if provider.startswith('__'): if provider_path.startswith('__'):
continue continue
# init provider # init provider
try: try:
provider_class = load_single_subclass_from_source( provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}', module_name=f'core.tools.provider.builtin.{provider_path}.{provider_path}',
script_path=path.join(path.dirname(path.realpath(__file__)), script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'), 'provider', 'builtin', provider_path, f'{provider_path}.py'),
parent_type=BuiltinToolProviderController) parent_type=BuiltinToolProviderController)
provider: BuiltinToolProviderController = provider_class() provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider cls._builtin_providers[provider.identity.name] = provider
@ -387,8 +396,8 @@ class ToolManager:
for provider in builtin_providers: for provider in builtin_providers:
# handle include, exclude # handle include, exclude
if is_filtered( if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider, data=provider,
name_func=lambda x: x.identity.name name_func=lambda x: x.identity.name
): ):
@ -461,7 +470,7 @@ class ToolManager:
:return: the provider controller, the credentials :return: the provider controller, the credentials
""" """
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
ApiToolProvider.id == provider_id, ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
).first() ).first()
@ -486,22 +495,22 @@ class ToolManager:
""" """
get tool provider get tool provider
""" """
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( provider_obj: ApiToolProvider| None = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider, ApiToolProvider.name == provider,
).first() ).first()
if provider is None: if provider_obj is None:
raise ValueError(f'you have not added provider {provider}') raise ValueError(f'you have not added provider {provider}')
try: try:
credentials = json.loads(provider.credentials_str) or {} credentials = json.loads(provider_obj.credentials_str) or {}
except: except:
credentials = {} credentials = {}
# package tool provider controller # package tool provider controller
controller = ApiToolProviderController.from_db( controller = ApiToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
) )
# init tool configuration # init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
@ -510,7 +519,7 @@ class ToolManager:
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
try: try:
icon = json.loads(provider.icon) icon = json.loads(provider_obj.icon)
except: except:
icon = { icon = {
"background": "#252525", "background": "#252525",
@ -521,14 +530,14 @@ class ToolManager:
labels = ToolLabelManager.get_tool_labels(controller) labels = ToolLabelManager.get_tool_labels(controller)
return jsonable_encoder({ return jsonable_encoder({
'schema_type': provider.schema_type, 'schema_type': provider_obj.schema_type,
'schema': provider.schema, 'schema': provider_obj.schema,
'tools': provider.tools, 'tools': provider_obj.tools,
'icon': icon, 'icon': icon,
'description': provider.description, 'description': provider_obj.description,
'credentials': masked_credentials, 'credentials': masked_credentials,
'privacy_policy': provider.privacy_policy, 'privacy_policy': provider_obj.privacy_policy,
'custom_disclaimer': provider.custom_disclaimer, 'custom_disclaimer': provider_obj.custom_disclaimer,
'labels': labels, 'labels': labels,
}) })
@ -551,25 +560,29 @@ class ToolManager:
+ "/icon") + "/icon")
elif provider_type == ToolProviderType.API: elif provider_type == ToolProviderType.API:
try: try:
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( api_provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.id == provider_id ApiToolProvider.id == provider_id
).first() ).first()
return json.loads(provider.icon) if not api_provider:
raise ValueError("api tool not found")
return json.loads(api_provider.icon)
except: except:
return { return {
"background": "#252525", "background": "#252525",
"content": "\ud83d\ude01" "content": "\ud83d\ude01"
} }
elif provider_type == ToolProviderType.WORKFLOW: elif provider_type == ToolProviderType.WORKFLOW:
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( workflow_provider: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == provider_id WorkflowToolProvider.id == provider_id
).first() ).first()
if provider is None:
if workflow_provider is None:
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
return json.loads(provider.icon) return json.loads(workflow_provider.icon)
else: else:
raise ValueError(f"provider type {provider_type} not found") raise ValueError(f"provider type {provider_type} not found")

View File

@ -7,8 +7,8 @@ from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ProviderConfig,
ToolParameter, ToolParameter,
ToolProviderCredentials,
ToolProviderType, ToolProviderType,
) )
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
@ -36,7 +36,7 @@ class ToolConfigurationManager(BaseModel):
# get fields need to be decrypted # get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema() fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: if field.type == ProviderConfig.Type.SECRET_INPUT:
if field_name in credentials: if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name]) encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted credentials[field_name] = encrypted
@ -54,7 +54,7 @@ class ToolConfigurationManager(BaseModel):
# get fields need to be decrypted # get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema() fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: if field.type == ProviderConfig.Type.SECRET_INPUT:
if field_name in credentials: if field_name in credentials:
if len(credentials[field_name]) > 6: if len(credentials[field_name]) > 6:
credentials[field_name] = \ credentials[field_name] = \
@ -84,7 +84,7 @@ class ToolConfigurationManager(BaseModel):
# get fields need to be decrypted # get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema() fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: if field.type == ProviderConfig.Type.SECRET_INPUT:
if field_name in credentials: if field_name in credentials:
try: try:
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name]) credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])

View File

@ -1,3 +1,5 @@
from collections.abc import Mapping
from core.app.app_config.entities import VariableEntity from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
@ -13,7 +15,7 @@ class WorkflowToolConfigurationUtils:
raise ValueError('invalid parameter configuration') raise ValueError('invalid parameter configuration')
@classmethod @classmethod
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: def get_workflow_graph_variables(cls, graph: Mapping) -> list[VariableEntity]:
""" """
get workflow graph variables get workflow graph variables
""" """
@ -44,5 +46,3 @@ class WorkflowToolConfigurationUtils:
for parameter in tool_configurations: for parameter in tool_configurations:
if parameter.name not in variable_names: if parameter.name not in variable_names:
raise ValueError('parameter configuration mismatch, please republish the tool to update') raise ValueError('parameter configuration mismatch, please republish the tool to update')
return True

View File

@ -10,8 +10,8 @@ from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ApiProviderSchemaType, ApiProviderSchemaType,
ProviderConfig,
ToolCredentialsOption, ToolCredentialsOption,
ToolProviderCredentials,
) )
from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
@ -39,9 +39,9 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}") raise ValueError(f"invalid schema: {str(e)}")
credentials_schema = [ credentials_schema = [
ToolProviderCredentials( ProviderConfig(
name="auth_type", name="auth_type",
type=ToolProviderCredentials.CredentialsType.SELECT, type=ProviderConfig.Type.SELECT,
required=True, required=True,
default="none", default="none",
options=[ options=[
@ -50,17 +50,17 @@ class ApiToolManageService:
], ],
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
), ),
ToolProviderCredentials( ProviderConfig(
name="api_key_header", name="api_key_header",
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, type=ProviderConfig.Type.TEXT_INPUT,
required=False, required=False,
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key headerX-API-KEY"), placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key headerX-API-KEY"),
default="api_key", default="api_key",
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
), ),
ToolProviderCredentials( ProviderConfig(
name="api_key_value", name="api_key_value",
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, type=ProviderConfig.Type.TEXT_INPUT,
required=False, required=False,
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
default="", default="",

View File

@ -8,8 +8,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ProviderConfig,
ToolParameter, ToolParameter,
ToolProviderCredentials,
ToolProviderType, ToolProviderType,
) )
from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.api_tool_provider import ApiToolProviderController
@ -92,7 +92,7 @@ class ToolTransformService:
# get credentials schema # get credentials schema
schema = provider_controller.get_credentials_schema() schema = provider_controller.get_credentials_schema()
for name, value in schema.items(): for name, value in schema.items():
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) result.masked_credentials[name] = ProviderConfig.Type.default(value.type)
# check if the provider need credentials # check if the provider need credentials
if not provider_controller.need_credentials: if not provider_controller.need_credentials: