refactor tools

This commit is contained in:
Yeuoly 2024-08-30 14:23:14 +08:00
parent 50a5cfe56a
commit 1fa3b9cfd8
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
20 changed files with 239 additions and 435 deletions

View File

@ -10,6 +10,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeNode,
@ -132,6 +133,14 @@ class PluginInvokeAppApi(Resource):
PluginAppBackwardsInvocation.convert_to_event_stream(response)
)
class PluginInvokeEncryptApi(Resource):
@setup_required
@plugin_inner_api_only
@get_tenant
@plugin_data(payload_type=RequestInvokeEncrypt)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeEncrypt):
""""""
api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')

View File

@ -46,6 +46,8 @@ def enterprise_inner_api_user_auth(view):
user_id = user_id.split(" ")[1]
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key:
raise ValueError("inner api key not found")
data_to_sign = f"DIFY {user_id}"

View File

@ -60,7 +60,7 @@ class QueueIterationStartEvent(AppQueueEvent):
node_data: BaseNodeData
node_run_index: int
inputs: dict = None
inputs: Optional[dict] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict] = None

View File

@ -0,0 +1,30 @@
from enum import Enum
class CommonParameterType(Enum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
BOOLEAN = "boolean"
APP_SELECTOR = "app-selector"
MODEL_CONFIG = "model-config"
class AppSelectorScope(Enum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
class ModelConfigScope(Enum):
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"

View File

@ -1,8 +1,10 @@
from enum import Enum
from typing import Optional
from typing import Optional, Union
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
@ -100,3 +102,52 @@ class ModelSettings(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class BasicProviderConfig(BaseModel):
"""
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
@staticmethod
def default(value: str) -> str:
return ""
type: Type = Field(..., description="The type of the credentials")
name: str = Field(..., description="The name of the credentials")
class ProviderConfig(BasicProviderConfig):
"""
Model class for common provider settings like credentials
"""
class Option(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
scope: AppSelectorScope | ModelConfigScope | None
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None

View File

@ -1,4 +1,9 @@
tool_file_manager = {
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
tool_file_manager: dict[str, Any] = {
'manager': None
}

View File

@ -1,7 +1,9 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.entities.provider_entities import BasicProviderConfig
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -30,11 +32,10 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
"""
Request to invoke LLM
"""
model_type: ModelType = ModelType.LLM
mode: str
model_parameters: dict[str, Any] = Field(default_factory=dict)
prompt_messages: list[PromptMessage]
prompt_messages: list[PromptMessage] = Field(default_factory=list)
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
stop: Optional[list[str]] = Field(default_factory=list)
stream: Optional[bool] = False
@ -105,4 +106,11 @@ class RequestInvokeApp(BaseModel):
conversation_id: Optional[str] = None
user: Optional[str] = None
files: list[dict] = Field(default_factory=list)
class RequestInvokeEncrypt(BaseModel):
"""
Request to encryption
"""
opt: Literal["encrypt", "decrypt"]
data: dict = Field(default_factory=dict)
config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
from core.tools.tool.tool import ToolParameter
@ -62,4 +62,4 @@ class UserToolProvider(BaseModel):
}
class UserToolProviderCredentials(BaseModel):
credentials: dict[str, ToolProviderCredentials]
credentials: dict[str, ProviderConfig]

View File

@ -3,6 +3,7 @@ from typing import Any, Optional, Union, cast
from pydantic import BaseModel, Field, field_validator
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.tools.entities.common_entities import I18nObject
@ -137,12 +138,12 @@ class ToolParameterOption(BaseModel):
class ToolParameter(BaseModel):
class ToolParameterType(str, Enum):
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
SELECT = "select"
SECRET_INPUT = "secret-input"
FILE = "file"
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
@ -151,16 +152,17 @@ class ToolParameter(BaseModel):
name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user")
human_description: Optional[I18nObject] = Field(None, description="The description presented to the user")
placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user")
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
type: ToolParameterType = Field(..., description="The type of the parameter")
scope: AppSelectorScope | ModelConfigScope | None = None
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None
required: Optional[bool] = False
default: Optional[Union[float, int, str]] = None
min: Optional[Union[float, int]] = None
max: Optional[Union[float, int]] = None
options: Optional[list[ToolParameterOption]] = None
options: list[ToolParameterOption] = Field(default_factory=list)
@classmethod
def get_simple_instance(cls,
@ -211,57 +213,6 @@ class ToolIdentity(BaseModel):
provider: str = Field(..., description="The provider of the tool")
icon: Optional[str] = None
class ToolCredentialsOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
class ToolProviderCredentials(BaseModel):
class CredentialsType(Enum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
BOOLEAN = "boolean"
@classmethod
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
@staticmethod
def default(value: str) -> str:
return ""
name: str = Field(..., description="The name of the credentials")
type: CredentialsType = Field(..., description="The type of the credentials")
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[ToolCredentialsOption]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None
def to_dict(self) -> dict:
return {
'name': self.name,
'type': self.type.value,
'required': self.required,
'default': self.default,
'options': self.options,
'help': self.help.to_dict() if self.help else None,
'label': self.label.to_dict() if self.label else None,
'url': self.url,
'placeholder': self.placeholder.to_dict() if self.placeholder else None,
}
class ToolRuntimeVariableType(Enum):
TEXT = "text"
IMAGE = "image"

View File

@ -3,8 +3,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ProviderConfig,
ToolCredentialsOption,
ToolProviderCredentials,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
@ -20,10 +20,10 @@ class ApiToolProviderController(ToolProviderController):
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
credentials_schema = {
'auth_type': ToolProviderCredentials(
'auth_type': ProviderConfig(
name='auth_type',
required=True,
type=ToolProviderCredentials.CredentialsType.SELECT,
type=ProviderConfig.Type.SELECT,
options=[
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
@ -38,30 +38,30 @@ class ApiToolProviderController(ToolProviderController):
if auth_type == ApiProviderAuthType.API_KEY:
credentials_schema = {
**credentials_schema,
'api_key_header': ToolProviderCredentials(
'api_key_header': ProviderConfig(
name='api_key_header',
required=False,
default='api_key',
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
type=ProviderConfig.Type.TEXT_INPUT,
help=I18nObject(
en_US='The header name of the api key',
zh_Hans='携带 api key 的 header 名称'
)
),
'api_key_value': ToolProviderCredentials(
'api_key_value': ProviderConfig(
name='api_key_value',
required=True,
type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
type=ProviderConfig.Type.SECRET_INPUT,
help=I18nObject(
en_US='The api key',
zh_Hans='api key的值'
)
),
'api_key_header_prefix': ToolProviderCredentials(
'api_key_header_prefix': ProviderConfig(
name='api_key_header_prefix',
required=False,
default='basic',
type=ToolProviderCredentials.CredentialsType.SELECT,
type=ProviderConfig.Type.SELECT,
help=I18nObject(
en_US='The prefix of the api key header',
zh_Hans='api key header 的前缀'

View File

@ -1,115 +0,0 @@
import logging
from typing import Any
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import PublishedAppTool
logger = logging.getLogger(__name__)
class AppToolProviderEntity(ToolProviderController):
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.APP
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
pass
def get_tools(self, user_id: str) -> list[Tool]:
db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter(
PublishedAppTool.user_id == user_id,
).all()
if not db_tools or len(db_tools) == 0:
return []
tools: list[Tool] = []
for db_tool in db_tools:
tool = {
'identity': {
'author': db_tool.author,
'name': db_tool.tool_name,
'label': {
'en_US': db_tool.tool_name,
'zh_Hans': db_tool.tool_name
},
'icon': ''
},
'description': {
'human': {
'en_US': db_tool.description_i18n.en_US,
'zh_Hans': db_tool.description_i18n.zh_Hans
},
'llm': db_tool.llm_description
},
'parameters': []
}
# get app from db
app: App = db_tool.app
if not app:
logger.error(f"app {db_tool.app_id} not found")
continue
app_model_config: AppModelConfig = app.app_model_config
user_input_form_list = app_model_config.user_input_form_list
for input_form in user_input_form_list:
# get type
form_type = input_form.keys()[0]
default = input_form[form_type]['default']
required = input_form[form_type]['required']
label = input_form[form_type]['label']
variable_name = input_form[form_type]['variable_name']
options = input_form[form_type].get('options', [])
if form_type == 'paragraph' or form_type == 'text-input':
tool['parameters'].append(ToolParameter(
name=variable_name,
label=I18nObject(
en_US=label,
zh_Hans=label
),
human_description=I18nObject(
en_US=label,
zh_Hans=label
),
llm_description=label,
form=ToolParameter.ToolParameterForm.FORM,
type=ToolParameter.ToolParameterType.STRING,
required=required,
default=default
))
elif form_type == 'select':
tool['parameters'].append(ToolParameter(
name=variable_name,
label=I18nObject(
en_US=label,
zh_Hans=label
),
human_description=I18nObject(
en_US=label,
zh_Hans=label
),
llm_description=label,
form=ToolParameter.ToolParameterForm.FORM,
type=ToolParameter.ToolParameterType.SELECT,
required=required,
default=default,
options=[ToolParameterOption(
value=option,
label=I18nObject(
en_US=option,
zh_Hans=option
)
) for option in options]
))
tools.append(Tool(**tool))
return tools

View File

@ -2,22 +2,23 @@ from abc import abstractmethod
from os import listdir, path
from typing import Any
from pydantic import Field
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolNotFoundError,
ToolParameterValidationError,
ToolProviderNotFoundError,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController):
tools: list[BuiltinTool] = Field(default_factory=list)
def __init__(self, **data: Any) -> None:
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
super().__init__(**data)
@ -41,7 +42,7 @@ class BuiltinToolProviderController(ToolProviderController):
'credentials_schema': provider_yaml.get('credentials_for_provider', None),
})
def _get_builtin_tools(self) -> list[Tool]:
def _get_builtin_tools(self) -> list[BuiltinTool]:
"""
returns a list of tools that the provider can provide
@ -72,7 +73,7 @@ class BuiltinToolProviderController(ToolProviderController):
self.tools = tools
return tools
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
"""
returns the credentials schema of the provider
@ -83,7 +84,7 @@ class BuiltinToolProviderController(ToolProviderController):
return self.credentials_schema.copy()
def get_tools(self) -> list[Tool]:
def get_tools(self) -> list[BuiltinTool]:
"""
returns a list of tools that the provider can provide
@ -91,24 +92,12 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> Tool:
def get_tool(self, tool_name: str) -> BuiltinTool | None:
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
returns the parameters of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
if tool is None:
raise ToolNotFoundError(f'tool {tool_name} not found')
return tool.parameters
@property
def need_credentials(self) -> bool:
"""
@ -143,67 +132,6 @@ class BuiltinToolProviderController(ToolProviderController):
returns the labels of the provider
"""
return self.identity.tags or []
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
"""
validate the parameters of the tool and set the default value if needed
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
"""
tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], int | float):
raise ToolParameterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required:
raise ToolParameterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
parameter_schema.type)
tool_parameters[parameter] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""

View File

@ -1,25 +1,23 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderCredentials,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
class ToolProviderController(BaseModel, ABC):
identity: Optional[ToolProviderIdentity] = None
tools: Optional[list[Tool]] = None
credentials_schema: Optional[dict[str, ToolProviderCredentials]] = None
identity: ToolProviderIdentity
tools: list[Tool] = Field(default_factory=list)
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
"""
returns the credentials schema of the provider
@ -27,15 +25,6 @@ class ToolProviderController(BaseModel, ABC):
"""
return self.credentials_schema.copy()
@abstractmethod
def get_tools(self) -> list[Tool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
pass
@abstractmethod
def get_tool(self, tool_name: str) -> Tool:
"""
@ -45,18 +34,6 @@ class ToolProviderController(BaseModel, ABC):
"""
pass
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
returns the parameters of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
if tool is None:
raise ToolNotFoundError(f'tool {tool_name} not found')
return tool.parameters
@property
def provider_type(self) -> ToolProviderType:
"""
@ -66,66 +43,6 @@ class ToolProviderController(BaseModel, ABC):
"""
return ToolProviderType.BUILT_IN
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
"""
validate the parameters of the tool and set the default value if needed
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
"""
tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], int | float):
raise ToolParameterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required:
raise ToolParameterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
parameter_schema.type)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
@ -136,7 +53,7 @@ class ToolProviderController(BaseModel, ABC):
if credentials_schema is None:
return
credentials_need_to_validate: dict[str, ToolProviderCredentials] = {}
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
@ -146,12 +63,12 @@ class ToolProviderController(BaseModel, ABC):
# check type
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
if credential_schema == ProviderConfig.Type.SECRET_INPUT or \
credential_schema == ProviderConfig.Type.TEXT_INPUT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
@ -173,9 +90,9 @@ class ToolProviderController(BaseModel, ABC):
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
if credential_schema.type == ProviderConfig.Type.SECRET_INPUT or \
credential_schema.type == ProviderConfig.Type.TEXT_INPUT or \
credential_schema.type == ProviderConfig.Type.SELECT:
default_value = str(default_value)
credentials[credential_name] = default_value

View File

@ -1,5 +1,8 @@
from collections.abc import Mapping
from typing import Optional
from pydantic import Field
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
@ -28,6 +31,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
class WorkflowToolProviderController(ToolProviderController):
provider_id: str
tools: list[WorkflowTool] = Field(default_factory=list)
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
@ -71,16 +75,17 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app
:return: the tool
"""
workflow: Workflow = db.session.query(Workflow).filter(
workflow: Workflow | None = db.session.query(Workflow).filter(
Workflow.app_id == db_provider.app_id,
Workflow.version == db_provider.version
).first()
if not workflow:
raise ValueError('workflow not found')
# fetch start node
graph: dict = workflow.graph_dict
features_dict: dict = workflow.features_dict
graph: Mapping = workflow.graph_dict
features_dict: Mapping = workflow.features_dict
features = WorkflowAppConfigManager.convert_features(
config_dict=features_dict,
app_mode=AppMode.WORKFLOW
@ -89,7 +94,7 @@ class WorkflowToolProviderController(ToolProviderController):
parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user
@ -99,7 +104,7 @@ class WorkflowToolProviderController(ToolProviderController):
variable = fetch_workflow_variable(parameter.name)
if variable:
parameter_type = None
options = None
options = []
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f'unsupported variable type {variable.type}')
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
@ -185,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
label=db_provider.label
)
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
"""
fetch tools from database
@ -196,7 +201,7 @@ class WorkflowToolProviderController(ToolProviderController):
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
).first()

View File

@ -55,7 +55,7 @@ class Tool(BaseModel, ABC):
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: Optional[dict[str, Any]] = None
runtime_parameters: Optional[dict[str, Any]] = None
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
runtime: Optional[Runtime] = None
variables: Optional[ToolRuntimeVariablePool] = None

View File

@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import Any, Union
from typing import Any, Union, cast
from configs import dify_config
from core.agent.entities import AgentToolEntity
@ -22,6 +22,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
@ -57,7 +58,7 @@ class ToolManager:
return cls._builtin_providers[provider]
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
"""
get the builtin tool
@ -78,7 +79,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]:
-> Union[BuiltinTool, ApiTool, WorkflowTool]:
"""
get the tool runtime
@ -90,19 +91,21 @@ class ToolManager:
"""
if provider_type == ToolProviderType.BUILT_IN:
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
if not builtin_tool:
raise ValueError(f"tool {tool_name} not found")
# check if the builtin tool need credentials
provider_controller = cls.get_builtin_provider(provider_id)
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(runtime={
return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': {},
'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from,
})
}))
# get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
builtin_provider: BuiltinToolProvider | None = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_id,
).first()
@ -117,13 +120,13 @@ class ToolManager:
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return builtin_tool.fork_tool_runtime(runtime={
return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
'runtime_parameters': {},
'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from,
})
}))
elif provider_type == ToolProviderType.API:
if tenant_id is None:
@ -135,12 +138,12 @@ class ToolManager:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from,
})
}))
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
@ -154,12 +157,12 @@ class ToolManager:
db_provider=workflow_provider
)
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
return cast(WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': {},
'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from,
})
}))
elif provider_type == ToolProviderType.APP:
raise NotImplementedError('app provider not implemented')
else:
@ -220,7 +223,10 @@ class ToolManager:
identity_id=f'AGENT.{app_id}'
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if not tool_entity.runtime:
raise Exception("tool missing runtime")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@ -258,6 +264,9 @@ class ToolManager:
if runtime_parameters:
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if not tool_entity.runtime:
raise Exception("tool missing runtime")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@ -304,20 +313,20 @@ class ToolManager:
"""
list all the builtin providers
"""
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'):
for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider_path.startswith('__'):
continue
if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
if provider.startswith('__'):
if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider_path)):
if provider_path.startswith('__'):
continue
# init provider
try:
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
module_name=f'core.tools.provider.builtin.{provider_path}.{provider_path}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
'provider', 'builtin', provider_path, f'{provider_path}.py'),
parent_type=BuiltinToolProviderController)
provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider
@ -387,8 +396,8 @@ class ToolManager:
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider,
name_func=lambda x: x.identity.name
):
@ -461,7 +470,7 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
).first()
@ -486,22 +495,22 @@ class ToolManager:
"""
get tool provider
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
provider_obj: ApiToolProvider| None = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
).first()
if provider is None:
if provider_obj is None:
raise ValueError(f'you have not added provider {provider}')
try:
credentials = json.loads(provider.credentials_str) or {}
credentials = json.loads(provider_obj.credentials_str) or {}
except:
credentials = {}
# package tool provider controller
controller = ApiToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
@ -510,7 +519,7 @@ class ToolManager:
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
try:
icon = json.loads(provider.icon)
icon = json.loads(provider_obj.icon)
except:
icon = {
"background": "#252525",
@ -521,14 +530,14 @@ class ToolManager:
labels = ToolLabelManager.get_tool_labels(controller)
return jsonable_encoder({
'schema_type': provider.schema_type,
'schema': provider.schema,
'tools': provider.tools,
'schema_type': provider_obj.schema_type,
'schema': provider_obj.schema,
'tools': provider_obj.tools,
'icon': icon,
'description': provider.description,
'description': provider_obj.description,
'credentials': masked_credentials,
'privacy_policy': provider.privacy_policy,
'custom_disclaimer': provider.custom_disclaimer,
'privacy_policy': provider_obj.privacy_policy,
'custom_disclaimer': provider_obj.custom_disclaimer,
'labels': labels,
})
@ -551,25 +560,29 @@ class ToolManager:
+ "/icon")
elif provider_type == ToolProviderType.API:
try:
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
api_provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.id == provider_id
).first()
return json.loads(provider.icon)
if not api_provider:
raise ValueError("api tool not found")
return json.loads(api_provider.icon)
except:
return {
"background": "#252525",
"content": "\ud83d\ude01"
}
elif provider_type == ToolProviderType.WORKFLOW:
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
workflow_provider: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == provider_id
).first()
if provider is None:
if workflow_provider is None:
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
return json.loads(provider.icon)
return json.loads(workflow_provider.icon)
else:
raise ValueError(f"provider type {provider_type} not found")

View File

@ -7,8 +7,8 @@ from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import (
ProviderConfig,
ToolParameter,
ToolProviderCredentials,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
@ -36,7 +36,7 @@ class ToolConfigurationManager(BaseModel):
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted
@ -54,7 +54,7 @@ class ToolConfigurationManager(BaseModel):
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
if len(credentials[field_name]) > 6:
credentials[field_name] = \
@ -84,7 +84,7 @@ class ToolConfigurationManager(BaseModel):
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
try:
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])

View File

@ -1,3 +1,5 @@
from collections.abc import Mapping
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
@ -13,7 +15,7 @@ class WorkflowToolConfigurationUtils:
raise ValueError('invalid parameter configuration')
@classmethod
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
def get_workflow_graph_variables(cls, graph: Mapping) -> list[VariableEntity]:
"""
get workflow graph variables
"""
@ -44,5 +46,3 @@ class WorkflowToolConfigurationUtils:
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError('parameter configuration mismatch, please republish the tool to update')
return True

View File

@ -10,8 +10,8 @@ from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
ProviderConfig,
ToolCredentialsOption,
ToolProviderCredentials,
)
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
@ -39,9 +39,9 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}")
credentials_schema = [
ToolProviderCredentials(
ProviderConfig(
name="auth_type",
type=ToolProviderCredentials.CredentialsType.SELECT,
type=ProviderConfig.Type.SELECT,
required=True,
default="none",
options=[
@ -50,17 +50,17 @@ class ApiToolManageService:
],
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
),
ToolProviderCredentials(
ProviderConfig(
name="api_key_header",
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
type=ProviderConfig.Type.TEXT_INPUT,
required=False,
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key headerX-API-KEY"),
default="api_key",
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
),
ToolProviderCredentials(
ProviderConfig(
name="api_key_value",
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
type=ProviderConfig.Type.TEXT_INPUT,
required=False,
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
default="",

View File

@ -8,8 +8,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ProviderConfig,
ToolParameter,
ToolProviderCredentials,
ToolProviderType,
)
from core.tools.provider.api_tool_provider import ApiToolProviderController
@ -92,7 +92,7 @@ class ToolTransformService:
# get credentials schema
schema = provider_controller.get_credentials_schema()
for name, value in schema.items():
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
result.masked_credentials[name] = ProviderConfig.Type.default(value.type)
# check if the provider need credentials
if not provider_controller.need_credentials: