mirror of https://github.com/langgenius/dify.git
refactor tools
This commit is contained in:
parent
50a5cfe56a
commit
1fa3b9cfd8
|
|
@ -10,6 +10,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
|||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeApp,
|
||||
RequestInvokeEncrypt,
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeNode,
|
||||
|
|
@ -132,6 +133,14 @@ class PluginInvokeAppApi(Resource):
|
|||
PluginAppBackwardsInvocation.convert_to_event_stream(response)
|
||||
)
|
||||
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_tenant
|
||||
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeEncrypt):
|
||||
""""""
|
||||
|
||||
api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
|
||||
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
|
||||
api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ def enterprise_inner_api_user_auth(view):
|
|||
user_id = user_id.split(" ")[1]
|
||||
|
||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||
if not inner_api_key:
|
||||
raise ValueError("inner api key not found")
|
||||
|
||||
data_to_sign = f"DIFY {user_id}"
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class QueueIterationStartEvent(AppQueueEvent):
|
|||
node_data: BaseNodeData
|
||||
|
||||
node_run_index: int
|
||||
inputs: dict = None
|
||||
inputs: Optional[dict] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class CommonParameterType(Enum):
|
||||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
FILE = "file"
|
||||
BOOLEAN = "boolean"
|
||||
APP_SELECTOR = "app-selector"
|
||||
MODEL_CONFIG = "model-config"
|
||||
|
||||
|
||||
class AppSelectorScope(Enum):
|
||||
ALL = "all"
|
||||
CHAT = "chat"
|
||||
WORKFLOW = "workflow"
|
||||
COMPLETION = "completion"
|
||||
|
||||
|
||||
class ModelConfigScope(Enum):
|
||||
LLM = "llm"
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
TTS = "tts"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
VISION = "vision"
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
|
@ -100,3 +102,52 @@ class ModelSettings(BaseModel):
|
|||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
class BasicProviderConfig(BaseModel):
|
||||
"""
|
||||
Base model class for common provider settings like credentials
|
||||
"""
|
||||
class Type(Enum):
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
@staticmethod
|
||||
def default(value: str) -> str:
|
||||
return ""
|
||||
|
||||
type: Type = Field(..., description="The type of the credentials")
|
||||
name: str = Field(..., description="The name of the credentials")
|
||||
|
||||
class ProviderConfig(BasicProviderConfig):
|
||||
"""
|
||||
Model class for common provider settings like credentials
|
||||
"""
|
||||
class Option(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
scope: AppSelectorScope | ModelConfigScope | None
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[Option]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
url: Optional[str] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
tool_file_manager = {
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
tool_file_manager: dict[str, Any] = {
|
||||
'manager': None
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
|
|
@ -30,11 +32,10 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
|||
"""
|
||||
Request to invoke LLM
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.LLM
|
||||
mode: str
|
||||
model_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
prompt_messages: list[PromptMessage]
|
||||
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
|
||||
stop: Optional[list[str]] = Field(default_factory=list)
|
||||
stream: Optional[bool] = False
|
||||
|
|
@ -105,4 +106,11 @@ class RequestInvokeApp(BaseModel):
|
|||
conversation_id: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
files: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RequestInvokeEncrypt(BaseModel):
|
||||
"""
|
||||
Request to encryption
|
||||
"""
|
||||
opt: Literal["encrypt", "decrypt"]
|
||||
data: dict = Field(default_factory=dict)
|
||||
config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
|
||||
from core.tools.tool.tool import ToolParameter
|
||||
|
||||
|
||||
|
|
@ -62,4 +62,4 @@ class UserToolProvider(BaseModel):
|
|||
}
|
||||
|
||||
class UserToolProviderCredentials(BaseModel):
|
||||
credentials: dict[str, ToolProviderCredentials]
|
||||
credentials: dict[str, ProviderConfig]
|
||||
|
|
@ -3,6 +3,7 @@ from typing import Any, Optional, Union, cast
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
|
|
@ -137,12 +138,12 @@ class ToolParameterOption(BaseModel):
|
|||
|
||||
class ToolParameter(BaseModel):
|
||||
class ToolParameterType(str, Enum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
SECRET_INPUT = "secret-input"
|
||||
FILE = "file"
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
|
|
@ -151,16 +152,17 @@ class ToolParameter(BaseModel):
|
|||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
human_description: Optional[I18nObject] = Field(None, description="The description presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user")
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
scope: AppSelectorScope | ModelConfigScope | None = None
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
default: Optional[Union[float, int, str]] = None
|
||||
min: Optional[Union[float, int]] = None
|
||||
max: Optional[Union[float, int]] = None
|
||||
options: Optional[list[ToolParameterOption]] = None
|
||||
options: list[ToolParameterOption] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(cls,
|
||||
|
|
@ -211,57 +213,6 @@ class ToolIdentity(BaseModel):
|
|||
provider: str = Field(..., description="The provider of the tool")
|
||||
icon: Optional[str] = None
|
||||
|
||||
class ToolCredentialsOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
class ToolProviderCredentials(BaseModel):
|
||||
class CredentialsType(Enum):
|
||||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
BOOLEAN = "boolean"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
@staticmethod
|
||||
def default(value: str) -> str:
|
||||
return ""
|
||||
|
||||
name: str = Field(..., description="The name of the credentials")
|
||||
type: CredentialsType = Field(..., description="The type of the credentials")
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[ToolCredentialsOption]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
url: Optional[str] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'name': self.name,
|
||||
'type': self.type.value,
|
||||
'required': self.required,
|
||||
'default': self.default,
|
||||
'options': self.options,
|
||||
'help': self.help.to_dict() if self.help else None,
|
||||
'label': self.label.to_dict() if self.label else None,
|
||||
'url': self.url,
|
||||
'placeholder': self.placeholder.to_dict() if self.placeholder else None,
|
||||
}
|
||||
|
||||
class ToolRuntimeVariableType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ from core.tools.entities.common_entities import I18nObject
|
|||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ProviderConfig,
|
||||
ToolCredentialsOption,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
|
|
@ -20,10 +20,10 @@ class ApiToolProviderController(ToolProviderController):
|
|||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
|
||||
credentials_schema = {
|
||||
'auth_type': ToolProviderCredentials(
|
||||
'auth_type': ProviderConfig(
|
||||
name='auth_type',
|
||||
required=True,
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
options=[
|
||||
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
|
||||
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
|
||||
|
|
@ -38,30 +38,30 @@ class ApiToolProviderController(ToolProviderController):
|
|||
if auth_type == ApiProviderAuthType.API_KEY:
|
||||
credentials_schema = {
|
||||
**credentials_schema,
|
||||
'api_key_header': ToolProviderCredentials(
|
||||
'api_key_header': ProviderConfig(
|
||||
name='api_key_header',
|
||||
required=False,
|
||||
default='api_key',
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(
|
||||
en_US='The header name of the api key',
|
||||
zh_Hans='携带 api key 的 header 名称'
|
||||
)
|
||||
),
|
||||
'api_key_value': ToolProviderCredentials(
|
||||
'api_key_value': ProviderConfig(
|
||||
name='api_key_value',
|
||||
required=True,
|
||||
type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
||||
type=ProviderConfig.Type.SECRET_INPUT,
|
||||
help=I18nObject(
|
||||
en_US='The api key',
|
||||
zh_Hans='api key的值'
|
||||
)
|
||||
),
|
||||
'api_key_header_prefix': ToolProviderCredentials(
|
||||
'api_key_header_prefix': ProviderConfig(
|
||||
name='api_key_header_prefix',
|
||||
required=False,
|
||||
default='basic',
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
help=I18nObject(
|
||||
en_US='The prefix of the api key header',
|
||||
zh_Hans='api key header 的前缀'
|
||||
|
|
|
|||
|
|
@ -1,115 +0,0 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
from models.tools import PublishedAppTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AppToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str) -> list[Tool]:
|
||||
db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter(
|
||||
PublishedAppTool.user_id == user_id,
|
||||
).all()
|
||||
|
||||
if not db_tools or len(db_tools) == 0:
|
||||
return []
|
||||
|
||||
tools: list[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool = {
|
||||
'identity': {
|
||||
'author': db_tool.author,
|
||||
'name': db_tool.tool_name,
|
||||
'label': {
|
||||
'en_US': db_tool.tool_name,
|
||||
'zh_Hans': db_tool.tool_name
|
||||
},
|
||||
'icon': ''
|
||||
},
|
||||
'description': {
|
||||
'human': {
|
||||
'en_US': db_tool.description_i18n.en_US,
|
||||
'zh_Hans': db_tool.description_i18n.zh_Hans
|
||||
},
|
||||
'llm': db_tool.llm_description
|
||||
},
|
||||
'parameters': []
|
||||
}
|
||||
# get app from db
|
||||
app: App = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
continue
|
||||
|
||||
app_model_config: AppModelConfig = app.app_model_config
|
||||
user_input_form_list = app_model_config.user_input_form_list
|
||||
for input_form in user_input_form_list:
|
||||
# get type
|
||||
form_type = input_form.keys()[0]
|
||||
default = input_form[form_type]['default']
|
||||
required = input_form[form_type]['required']
|
||||
label = input_form[form_type]['label']
|
||||
variable_name = input_form[form_type]['variable_name']
|
||||
options = input_form[form_type].get('options', [])
|
||||
if form_type == 'paragraph' or form_type == 'text-input':
|
||||
tool['parameters'].append(ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default
|
||||
))
|
||||
elif form_type == 'select':
|
||||
tool['parameters'].append(ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
options=[ToolParameterOption(
|
||||
value=option,
|
||||
label=I18nObject(
|
||||
en_US=option,
|
||||
zh_Hans=option
|
||||
)
|
||||
) for option in options]
|
||||
))
|
||||
|
||||
tools.append(Tool(**tool))
|
||||
return tools
|
||||
|
|
@ -2,22 +2,23 @@ from abc import abstractmethod
|
|||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolNotFoundError,
|
||||
ToolParameterValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
tools: list[BuiltinTool] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
|
||||
super().__init__(**data)
|
||||
|
|
@ -41,7 +42,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
'credentials_schema': provider_yaml.get('credentials_for_provider', None),
|
||||
})
|
||||
|
||||
def _get_builtin_tools(self) -> list[Tool]:
|
||||
def _get_builtin_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
|
|
@ -72,7 +73,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
|
|
@ -83,7 +84,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
def get_tools(self) -> list[Tool]:
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
|
|
@ -91,24 +92,12 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None:
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
|
||||
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
|
||||
"""
|
||||
returns the parameters of the tool
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:return: list of parameters
|
||||
"""
|
||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f'tool {tool_name} not found')
|
||||
return tool.parameters
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""
|
||||
|
|
@ -143,67 +132,6 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
returns the labels of the provider
|
||||
"""
|
||||
return self.identity.tags or []
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the parameters of the tool and set the default value if needed
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param tool_parameters: the parameters of the tool
|
||||
"""
|
||||
tool_parameters_schema = self.get_parameters(tool_name)
|
||||
|
||||
tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
|
||||
for parameter in tool_parameters_schema:
|
||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
||||
|
||||
for parameter in tool_parameters:
|
||||
if parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
|
||||
|
||||
# check type
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be string')
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
if not isinstance(tool_parameters[parameter], int | float):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be number')
|
||||
|
||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
|
||||
|
||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter], bool):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be string')
|
||||
|
||||
options = parameter_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
|
||||
|
||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
|
||||
|
||||
tool_parameters_need_to_validate.pop(parameter)
|
||||
|
||||
for parameter in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
if parameter_schema.required:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} is required')
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
|
||||
parameter_schema.type)
|
||||
tool_parameters[parameter] = default_value
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,25 +1,23 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
|
||||
class ToolProviderController(BaseModel, ABC):
|
||||
identity: Optional[ToolProviderIdentity] = None
|
||||
tools: Optional[list[Tool]] = None
|
||||
credentials_schema: Optional[dict[str, ToolProviderCredentials]] = None
|
||||
identity: ToolProviderIdentity
|
||||
tools: list[Tool] = Field(default_factory=list)
|
||||
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
|
|
@ -27,15 +25,6 @@ class ToolProviderController(BaseModel, ABC):
|
|||
"""
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> list[Tool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
:return: list of tools
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
|
|
@ -45,18 +34,6 @@ class ToolProviderController(BaseModel, ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
|
||||
"""
|
||||
returns the parameters of the tool
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:return: list of parameters
|
||||
"""
|
||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f'tool {tool_name} not found')
|
||||
return tool.parameters
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
|
|
@ -66,66 +43,6 @@ class ToolProviderController(BaseModel, ABC):
|
|||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the parameters of the tool and set the default value if needed
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param tool_parameters: the parameters of the tool
|
||||
"""
|
||||
tool_parameters_schema = self.get_parameters(tool_name)
|
||||
|
||||
tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
|
||||
for parameter in tool_parameters_schema:
|
||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
||||
|
||||
for parameter in tool_parameters:
|
||||
if parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
|
||||
|
||||
# check type
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be string')
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
if not isinstance(tool_parameters[parameter], int | float):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be number')
|
||||
|
||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
|
||||
|
||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter], bool):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be string')
|
||||
|
||||
options = parameter_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
|
||||
|
||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
|
||||
|
||||
tool_parameters_need_to_validate.pop(parameter)
|
||||
|
||||
for parameter in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
if parameter_schema.required:
|
||||
raise ToolParameterValidationError(f'parameter {parameter} is required')
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
|
||||
parameter_schema.type)
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
|
@ -136,7 +53,7 @@ class ToolProviderController(BaseModel, ABC):
|
|||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
credentials_need_to_validate: dict[str, ToolProviderCredentials] = {}
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
|
|
@ -146,12 +63,12 @@ class ToolProviderController(BaseModel, ABC):
|
|||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
|
||||
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
|
||||
if credential_schema == ProviderConfig.Type.SECRET_INPUT or \
|
||||
credential_schema == ProviderConfig.Type.TEXT_INPUT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
||||
|
||||
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
||||
|
||||
|
|
@ -173,9 +90,9 @@ class ToolProviderController(BaseModel, ABC):
|
|||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
|
||||
credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
|
||||
credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
|
||||
if credential_schema.type == ProviderConfig.Type.SECRET_INPUT or \
|
||||
credential_schema.type == ProviderConfig.Type.TEXT_INPUT or \
|
||||
credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
|
@ -28,6 +31,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
|||
|
||||
class WorkflowToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
tools: list[WorkflowTool] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
|
||||
|
|
@ -71,16 +75,17 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow = db.session.query(Workflow).filter(
|
||||
workflow: Workflow | None = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == db_provider.app_id,
|
||||
Workflow.version == db_provider.version
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('workflow not found')
|
||||
|
||||
# fetch start node
|
||||
graph: dict = workflow.graph_dict
|
||||
features_dict: dict = workflow.features_dict
|
||||
graph: Mapping = workflow.graph_dict
|
||||
features_dict: Mapping = workflow.features_dict
|
||||
features = WorkflowAppConfigManager.convert_features(
|
||||
config_dict=features_dict,
|
||||
app_mode=AppMode.WORKFLOW
|
||||
|
|
@ -89,7 +94,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
parameters = db_provider.parameter_configurations
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||
|
||||
user = db_provider.user
|
||||
|
|
@ -99,7 +104,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
variable = fetch_workflow_variable(parameter.name)
|
||||
if variable:
|
||||
parameter_type = None
|
||||
options = None
|
||||
options = []
|
||||
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
||||
raise ValueError(f'unsupported variable type {variable.type}')
|
||||
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
||||
|
|
@ -185,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
label=db_provider.label
|
||||
)
|
||||
|
||||
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
|
||||
def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
|
|
@ -196,7 +201,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
).first()
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class Tool(BaseModel, ABC):
|
|||
invoke_from: Optional[InvokeFrom] = None
|
||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||
credentials: Optional[dict[str, Any]] = None
|
||||
runtime_parameters: Optional[dict[str, Any]] = None
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
runtime: Optional[Runtime] = None
|
||||
variables: Optional[ToolRuntimeVariablePool] = None
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import mimetypes
|
|||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
|
|
@ -22,6 +22,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
|||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
|
@ -57,7 +58,7 @@ class ToolManager:
|
|||
return cls._builtin_providers[provider]
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
|
|
@ -78,7 +79,7 @@ class ToolManager:
|
|||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
-> Union[BuiltinTool, ApiTool, WorkflowTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
|
|
@ -90,19 +91,21 @@ class ToolManager:
|
|||
"""
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
||||
if not builtin_tool:
|
||||
raise ValueError(f"tool {tool_name} not found")
|
||||
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id)
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(runtime={
|
||||
return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
}))
|
||||
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
builtin_provider: BuiltinToolProvider | None = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_id,
|
||||
).first()
|
||||
|
|
@ -117,13 +120,13 @@ class ToolManager:
|
|||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return builtin_tool.fork_tool_runtime(runtime={
|
||||
return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': decrypted_credentials,
|
||||
'runtime_parameters': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
}))
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
if tenant_id is None:
|
||||
|
|
@ -135,12 +138,12 @@ class ToolManager:
|
|||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
|
||||
return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': decrypted_credentials,
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
}))
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
|
|
@ -154,12 +157,12 @@ class ToolManager:
|
|||
db_provider=workflow_provider
|
||||
)
|
||||
|
||||
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
|
||||
return cast(WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
}))
|
||||
elif provider_type == ToolProviderType.APP:
|
||||
raise NotImplementedError('app provider not implemented')
|
||||
else:
|
||||
|
|
@ -220,7 +223,10 @@ class ToolManager:
|
|||
identity_id=f'AGENT.{app_id}'
|
||||
)
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
|
||||
if not tool_entity.runtime:
|
||||
raise Exception("tool missing runtime")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
|
|
@ -258,6 +264,9 @@ class ToolManager:
|
|||
if runtime_parameters:
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
if not tool_entity.runtime:
|
||||
raise Exception("tool missing runtime")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
|
|
@ -304,20 +313,20 @@ class ToolManager:
|
|||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||
if provider.startswith('__'):
|
||||
for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||
if provider_path.startswith('__'):
|
||||
continue
|
||||
|
||||
if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
|
||||
if provider.startswith('__'):
|
||||
if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider_path)):
|
||||
if provider_path.startswith('__'):
|
||||
continue
|
||||
|
||||
# init provider
|
||||
try:
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
module_name=f'core.tools.provider.builtin.{provider_path}.{provider_path}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
'provider', 'builtin', provider_path, f'{provider_path}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.identity.name] = provider
|
||||
|
|
@ -387,8 +396,8 @@ class ToolManager:
|
|||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name
|
||||
):
|
||||
|
|
@ -461,7 +470,7 @@ class ToolManager:
|
|||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.id == provider_id,
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
).first()
|
||||
|
|
@ -486,22 +495,22 @@ class ToolManager:
|
|||
"""
|
||||
get tool provider
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
provider_obj: ApiToolProvider| None = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
).first()
|
||||
|
||||
if provider is None:
|
||||
if provider_obj is None:
|
||||
raise ValueError(f'you have not added provider {provider}')
|
||||
|
||||
try:
|
||||
credentials = json.loads(provider.credentials_str) or {}
|
||||
credentials = json.loads(provider_obj.credentials_str) or {}
|
||||
except:
|
||||
credentials = {}
|
||||
|
||||
# package tool provider controller
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
|
@ -510,7 +519,7 @@ class ToolManager:
|
|||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
|
||||
try:
|
||||
icon = json.loads(provider.icon)
|
||||
icon = json.loads(provider_obj.icon)
|
||||
except:
|
||||
icon = {
|
||||
"background": "#252525",
|
||||
|
|
@ -521,14 +530,14 @@ class ToolManager:
|
|||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return jsonable_encoder({
|
||||
'schema_type': provider.schema_type,
|
||||
'schema': provider.schema,
|
||||
'tools': provider.tools,
|
||||
'schema_type': provider_obj.schema_type,
|
||||
'schema': provider_obj.schema,
|
||||
'tools': provider_obj.tools,
|
||||
'icon': icon,
|
||||
'description': provider.description,
|
||||
'description': provider_obj.description,
|
||||
'credentials': masked_credentials,
|
||||
'privacy_policy': provider.privacy_policy,
|
||||
'custom_disclaimer': provider.custom_disclaimer,
|
||||
'privacy_policy': provider_obj.privacy_policy,
|
||||
'custom_disclaimer': provider_obj.custom_disclaimer,
|
||||
'labels': labels,
|
||||
})
|
||||
|
||||
|
|
@ -551,25 +560,29 @@ class ToolManager:
|
|||
+ "/icon")
|
||||
elif provider_type == ToolProviderType.API:
|
||||
try:
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
api_provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.id == provider_id
|
||||
).first()
|
||||
return json.loads(provider.icon)
|
||||
if not api_provider:
|
||||
raise ValueError("api tool not found")
|
||||
|
||||
return json.loads(api_provider.icon)
|
||||
except:
|
||||
return {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
workflow_provider: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == provider_id
|
||||
).first()
|
||||
if provider is None:
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
|
||||
|
||||
return json.loads(provider.icon)
|
||||
return json.loads(workflow_provider.icon)
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from core.helper import encrypter
|
|||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ProviderConfig,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
|
|
@ -36,7 +36,7 @@ class ToolConfigurationManager(BaseModel):
|
|||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
||||
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
|
||||
credentials[field_name] = encrypted
|
||||
|
|
@ -54,7 +54,7 @@ class ToolConfigurationManager(BaseModel):
|
|||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
||||
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
if len(credentials[field_name]) > 6:
|
||||
credentials[field_name] = \
|
||||
|
|
@ -84,7 +84,7 @@ class ToolConfigurationManager(BaseModel):
|
|||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
||||
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
try:
|
||||
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
|
@ -13,7 +15,7 @@ class WorkflowToolConfigurationUtils:
|
|||
raise ValueError('invalid parameter configuration')
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
||||
def get_workflow_graph_variables(cls, graph: Mapping) -> list[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
|
|
@ -44,5 +46,3 @@ class WorkflowToolConfigurationUtils:
|
|||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
|
||||
return True
|
||||
|
|
@ -10,8 +10,8 @@ from core.tools.entities.tool_bundle import ApiToolBundle
|
|||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ApiProviderSchemaType,
|
||||
ProviderConfig,
|
||||
ToolCredentialsOption,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
|
|
@ -39,9 +39,9 @@ class ApiToolManageService:
|
|||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
credentials_schema = [
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
required=True,
|
||||
default="none",
|
||||
options=[
|
||||
|
|
@ -50,17 +50,17 @@ class ApiToolManageService:
|
|||
],
|
||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="api_key_header",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
||||
default="api_key",
|
||||
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
||||
default="",
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from core.tools.entities.common_entities import I18nObject
|
|||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ProviderConfig,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
|
|
@ -92,7 +92,7 @@ class ToolTransformService:
|
|||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
for name, value in schema.items():
|
||||
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
result.masked_credentials[name] = ProviderConfig.Type.default(value.type)
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
|
|
|
|||
Loading…
Reference in New Issue