mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 20:17:29 +08:00
refactor tools
This commit is contained in:
parent
50a5cfe56a
commit
1fa3b9cfd8
@ -10,6 +10,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
|||||||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||||
from core.plugin.entities.request import (
|
from core.plugin.entities.request import (
|
||||||
RequestInvokeApp,
|
RequestInvokeApp,
|
||||||
|
RequestInvokeEncrypt,
|
||||||
RequestInvokeLLM,
|
RequestInvokeLLM,
|
||||||
RequestInvokeModeration,
|
RequestInvokeModeration,
|
||||||
RequestInvokeNode,
|
RequestInvokeNode,
|
||||||
@ -132,6 +133,14 @@ class PluginInvokeAppApi(Resource):
|
|||||||
PluginAppBackwardsInvocation.convert_to_event_stream(response)
|
PluginAppBackwardsInvocation.convert_to_event_stream(response)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class PluginInvokeEncryptApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||||
|
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeEncrypt):
|
||||||
|
""""""
|
||||||
|
|
||||||
api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
|
api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
|
||||||
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
|
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
|
||||||
api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')
|
api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')
|
||||||
|
|||||||
@ -46,6 +46,8 @@ def enterprise_inner_api_user_auth(view):
|
|||||||
user_id = user_id.split(" ")[1]
|
user_id = user_id.split(" ")[1]
|
||||||
|
|
||||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||||
|
if not inner_api_key:
|
||||||
|
raise ValueError("inner api key not found")
|
||||||
|
|
||||||
data_to_sign = f"DIFY {user_id}"
|
data_to_sign = f"DIFY {user_id}"
|
||||||
|
|
||||||
|
|||||||
@ -60,7 +60,7 @@ class QueueIterationStartEvent(AppQueueEvent):
|
|||||||
node_data: BaseNodeData
|
node_data: BaseNodeData
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: dict = None
|
inputs: Optional[dict] = None
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: Optional[str] = None
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None
|
||||||
|
|
||||||
|
|||||||
30
api/core/entities/parameter_entities.py
Normal file
30
api/core/entities/parameter_entities.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class CommonParameterType(Enum):
|
||||||
|
SECRET_INPUT = "secret-input"
|
||||||
|
TEXT_INPUT = "text-input"
|
||||||
|
SELECT = "select"
|
||||||
|
STRING = "string"
|
||||||
|
NUMBER = "number"
|
||||||
|
FILE = "file"
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
APP_SELECTOR = "app-selector"
|
||||||
|
MODEL_CONFIG = "model-config"
|
||||||
|
|
||||||
|
|
||||||
|
class AppSelectorScope(Enum):
|
||||||
|
ALL = "all"
|
||||||
|
CHAT = "chat"
|
||||||
|
WORKFLOW = "workflow"
|
||||||
|
COMPLETION = "completion"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfigScope(Enum):
|
||||||
|
LLM = "llm"
|
||||||
|
TEXT_EMBEDDING = "text-embedding"
|
||||||
|
RERANK = "rerank"
|
||||||
|
TTS = "tts"
|
||||||
|
SPEECH2TEXT = "speech2text"
|
||||||
|
MODERATION = "moderation"
|
||||||
|
VISION = "vision"
|
||||||
@ -1,8 +1,10 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from models.provider import ProviderQuotaType
|
from models.provider import ProviderQuotaType
|
||||||
|
|
||||||
@ -100,3 +102,52 @@ class ModelSettings(BaseModel):
|
|||||||
|
|
||||||
# pydantic configs
|
# pydantic configs
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
class BasicProviderConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Base model class for common provider settings like credentials
|
||||||
|
"""
|
||||||
|
class Type(Enum):
|
||||||
|
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||||
|
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
|
||||||
|
SELECT = CommonParameterType.SELECT.value
|
||||||
|
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||||
|
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||||
|
MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||||
|
"""
|
||||||
|
Get value of given mode.
|
||||||
|
|
||||||
|
:param value: mode value
|
||||||
|
:return: mode
|
||||||
|
"""
|
||||||
|
for mode in cls:
|
||||||
|
if mode.value == value:
|
||||||
|
return mode
|
||||||
|
raise ValueError(f'invalid mode value {value}')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default(value: str) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
type: Type = Field(..., description="The type of the credentials")
|
||||||
|
name: str = Field(..., description="The name of the credentials")
|
||||||
|
|
||||||
|
class ProviderConfig(BasicProviderConfig):
|
||||||
|
"""
|
||||||
|
Model class for common provider settings like credentials
|
||||||
|
"""
|
||||||
|
class Option(BaseModel):
|
||||||
|
value: str = Field(..., description="The value of the option")
|
||||||
|
label: I18nObject = Field(..., description="The label of the option")
|
||||||
|
|
||||||
|
scope: AppSelectorScope | ModelConfigScope | None
|
||||||
|
required: bool = False
|
||||||
|
default: Optional[Union[int, str]] = None
|
||||||
|
options: Optional[list[Option]] = None
|
||||||
|
label: Optional[I18nObject] = None
|
||||||
|
help: Optional[I18nObject] = None
|
||||||
|
url: Optional[str] = None
|
||||||
|
placeholder: Optional[I18nObject] = None
|
||||||
|
|||||||
@ -1,4 +1,9 @@
|
|||||||
tool_file_manager = {
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
|
tool_file_manager: dict[str, Any] = {
|
||||||
'manager': None
|
'manager': None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,9 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -30,11 +32,10 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
|||||||
"""
|
"""
|
||||||
Request to invoke LLM
|
Request to invoke LLM
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_type: ModelType = ModelType.LLM
|
model_type: ModelType = ModelType.LLM
|
||||||
mode: str
|
mode: str
|
||||||
model_parameters: dict[str, Any] = Field(default_factory=dict)
|
model_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
prompt_messages: list[PromptMessage]
|
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||||
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
|
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
|
||||||
stop: Optional[list[str]] = Field(default_factory=list)
|
stop: Optional[list[str]] = Field(default_factory=list)
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
@ -105,4 +106,11 @@ class RequestInvokeApp(BaseModel):
|
|||||||
conversation_id: Optional[str] = None
|
conversation_id: Optional[str] = None
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
files: list[dict] = Field(default_factory=list)
|
files: list[dict] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class RequestInvokeEncrypt(BaseModel):
|
||||||
|
"""
|
||||||
|
Request to encryption
|
||||||
|
"""
|
||||||
|
opt: Literal["encrypt", "decrypt"]
|
||||||
|
data: dict = Field(default_factory=dict)
|
||||||
|
config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
|
from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
|
||||||
from core.tools.tool.tool import ToolParameter
|
from core.tools.tool.tool import ToolParameter
|
||||||
|
|
||||||
|
|
||||||
@ -62,4 +62,4 @@ class UserToolProvider(BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
class UserToolProviderCredentials(BaseModel):
|
class UserToolProviderCredentials(BaseModel):
|
||||||
credentials: dict[str, ToolProviderCredentials]
|
credentials: dict[str, ProviderConfig]
|
||||||
@ -3,6 +3,7 @@ from typing import Any, Optional, Union, cast
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
|
||||||
|
|
||||||
@ -137,12 +138,12 @@ class ToolParameterOption(BaseModel):
|
|||||||
|
|
||||||
class ToolParameter(BaseModel):
|
class ToolParameter(BaseModel):
|
||||||
class ToolParameterType(str, Enum):
|
class ToolParameterType(str, Enum):
|
||||||
STRING = "string"
|
STRING = CommonParameterType.STRING.value
|
||||||
NUMBER = "number"
|
NUMBER = CommonParameterType.NUMBER.value
|
||||||
BOOLEAN = "boolean"
|
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||||
SELECT = "select"
|
SELECT = CommonParameterType.SELECT.value
|
||||||
SECRET_INPUT = "secret-input"
|
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||||
FILE = "file"
|
FILE = CommonParameterType.FILE.value
|
||||||
|
|
||||||
class ToolParameterForm(Enum):
|
class ToolParameterForm(Enum):
|
||||||
SCHEMA = "schema" # should be set while adding tool
|
SCHEMA = "schema" # should be set while adding tool
|
||||||
@ -151,16 +152,17 @@ class ToolParameter(BaseModel):
|
|||||||
|
|
||||||
name: str = Field(..., description="The name of the parameter")
|
name: str = Field(..., description="The name of the parameter")
|
||||||
label: I18nObject = Field(..., description="The label presented to the user")
|
label: I18nObject = Field(..., description="The label presented to the user")
|
||||||
human_description: Optional[I18nObject] = Field(None, description="The description presented to the user")
|
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||||
placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user")
|
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
|
||||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||||
|
scope: AppSelectorScope | ModelConfigScope | None = None
|
||||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||||
llm_description: Optional[str] = None
|
llm_description: Optional[str] = None
|
||||||
required: Optional[bool] = False
|
required: Optional[bool] = False
|
||||||
default: Optional[Union[float, int, str]] = None
|
default: Optional[Union[float, int, str]] = None
|
||||||
min: Optional[Union[float, int]] = None
|
min: Optional[Union[float, int]] = None
|
||||||
max: Optional[Union[float, int]] = None
|
max: Optional[Union[float, int]] = None
|
||||||
options: Optional[list[ToolParameterOption]] = None
|
options: list[ToolParameterOption] = Field(default_factory=list)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_simple_instance(cls,
|
def get_simple_instance(cls,
|
||||||
@ -211,57 +213,6 @@ class ToolIdentity(BaseModel):
|
|||||||
provider: str = Field(..., description="The provider of the tool")
|
provider: str = Field(..., description="The provider of the tool")
|
||||||
icon: Optional[str] = None
|
icon: Optional[str] = None
|
||||||
|
|
||||||
class ToolCredentialsOption(BaseModel):
|
|
||||||
value: str = Field(..., description="The value of the option")
|
|
||||||
label: I18nObject = Field(..., description="The label of the option")
|
|
||||||
|
|
||||||
class ToolProviderCredentials(BaseModel):
|
|
||||||
class CredentialsType(Enum):
|
|
||||||
SECRET_INPUT = "secret-input"
|
|
||||||
TEXT_INPUT = "text-input"
|
|
||||||
SELECT = "select"
|
|
||||||
BOOLEAN = "boolean"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
|
|
||||||
"""
|
|
||||||
Get value of given mode.
|
|
||||||
|
|
||||||
:param value: mode value
|
|
||||||
:return: mode
|
|
||||||
"""
|
|
||||||
for mode in cls:
|
|
||||||
if mode.value == value:
|
|
||||||
return mode
|
|
||||||
raise ValueError(f'invalid mode value {value}')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def default(value: str) -> str:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
name: str = Field(..., description="The name of the credentials")
|
|
||||||
type: CredentialsType = Field(..., description="The type of the credentials")
|
|
||||||
required: bool = False
|
|
||||||
default: Optional[Union[int, str]] = None
|
|
||||||
options: Optional[list[ToolCredentialsOption]] = None
|
|
||||||
label: Optional[I18nObject] = None
|
|
||||||
help: Optional[I18nObject] = None
|
|
||||||
url: Optional[str] = None
|
|
||||||
placeholder: Optional[I18nObject] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
'name': self.name,
|
|
||||||
'type': self.type.value,
|
|
||||||
'required': self.required,
|
|
||||||
'default': self.default,
|
|
||||||
'options': self.options,
|
|
||||||
'help': self.help.to_dict() if self.help else None,
|
|
||||||
'label': self.label.to_dict() if self.label else None,
|
|
||||||
'url': self.url,
|
|
||||||
'placeholder': self.placeholder.to_dict() if self.placeholder else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
class ToolRuntimeVariableType(Enum):
|
class ToolRuntimeVariableType(Enum):
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
|
|||||||
@ -3,8 +3,8 @@ from core.tools.entities.common_entities import I18nObject
|
|||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
|
ProviderConfig,
|
||||||
ToolCredentialsOption,
|
ToolCredentialsOption,
|
||||||
ToolProviderCredentials,
|
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.provider.tool_provider import ToolProviderController
|
from core.tools.provider.tool_provider import ToolProviderController
|
||||||
@ -20,10 +20,10 @@ class ApiToolProviderController(ToolProviderController):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
|
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
|
||||||
credentials_schema = {
|
credentials_schema = {
|
||||||
'auth_type': ToolProviderCredentials(
|
'auth_type': ProviderConfig(
|
||||||
name='auth_type',
|
name='auth_type',
|
||||||
required=True,
|
required=True,
|
||||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
type=ProviderConfig.Type.SELECT,
|
||||||
options=[
|
options=[
|
||||||
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
|
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
|
||||||
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
|
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
|
||||||
@ -38,30 +38,30 @@ class ApiToolProviderController(ToolProviderController):
|
|||||||
if auth_type == ApiProviderAuthType.API_KEY:
|
if auth_type == ApiProviderAuthType.API_KEY:
|
||||||
credentials_schema = {
|
credentials_schema = {
|
||||||
**credentials_schema,
|
**credentials_schema,
|
||||||
'api_key_header': ToolProviderCredentials(
|
'api_key_header': ProviderConfig(
|
||||||
name='api_key_header',
|
name='api_key_header',
|
||||||
required=False,
|
required=False,
|
||||||
default='api_key',
|
default='api_key',
|
||||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
type=ProviderConfig.Type.TEXT_INPUT,
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
en_US='The header name of the api key',
|
en_US='The header name of the api key',
|
||||||
zh_Hans='携带 api key 的 header 名称'
|
zh_Hans='携带 api key 的 header 名称'
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
'api_key_value': ToolProviderCredentials(
|
'api_key_value': ProviderConfig(
|
||||||
name='api_key_value',
|
name='api_key_value',
|
||||||
required=True,
|
required=True,
|
||||||
type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
type=ProviderConfig.Type.SECRET_INPUT,
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
en_US='The api key',
|
en_US='The api key',
|
||||||
zh_Hans='api key的值'
|
zh_Hans='api key的值'
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
'api_key_header_prefix': ToolProviderCredentials(
|
'api_key_header_prefix': ProviderConfig(
|
||||||
name='api_key_header_prefix',
|
name='api_key_header_prefix',
|
||||||
required=False,
|
required=False,
|
||||||
default='basic',
|
default='basic',
|
||||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
type=ProviderConfig.Type.SELECT,
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
en_US='The prefix of the api key header',
|
en_US='The prefix of the api key header',
|
||||||
zh_Hans='api key header 的前缀'
|
zh_Hans='api key header 的前缀'
|
||||||
|
|||||||
@ -1,115 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.tools.entities.common_entities import I18nObject
|
|
||||||
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
|
|
||||||
from core.tools.provider.tool_provider import ToolProviderController
|
|
||||||
from core.tools.tool.tool import Tool
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.model import App, AppModelConfig
|
|
||||||
from models.tools import PublishedAppTool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class AppToolProviderEntity(ToolProviderController):
|
|
||||||
@property
|
|
||||||
def provider_type(self) -> ToolProviderType:
|
|
||||||
return ToolProviderType.APP
|
|
||||||
|
|
||||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_tools(self, user_id: str) -> list[Tool]:
|
|
||||||
db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter(
|
|
||||||
PublishedAppTool.user_id == user_id,
|
|
||||||
).all()
|
|
||||||
|
|
||||||
if not db_tools or len(db_tools) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
tools: list[Tool] = []
|
|
||||||
|
|
||||||
for db_tool in db_tools:
|
|
||||||
tool = {
|
|
||||||
'identity': {
|
|
||||||
'author': db_tool.author,
|
|
||||||
'name': db_tool.tool_name,
|
|
||||||
'label': {
|
|
||||||
'en_US': db_tool.tool_name,
|
|
||||||
'zh_Hans': db_tool.tool_name
|
|
||||||
},
|
|
||||||
'icon': ''
|
|
||||||
},
|
|
||||||
'description': {
|
|
||||||
'human': {
|
|
||||||
'en_US': db_tool.description_i18n.en_US,
|
|
||||||
'zh_Hans': db_tool.description_i18n.zh_Hans
|
|
||||||
},
|
|
||||||
'llm': db_tool.llm_description
|
|
||||||
},
|
|
||||||
'parameters': []
|
|
||||||
}
|
|
||||||
# get app from db
|
|
||||||
app: App = db_tool.app
|
|
||||||
|
|
||||||
if not app:
|
|
||||||
logger.error(f"app {db_tool.app_id} not found")
|
|
||||||
continue
|
|
||||||
|
|
||||||
app_model_config: AppModelConfig = app.app_model_config
|
|
||||||
user_input_form_list = app_model_config.user_input_form_list
|
|
||||||
for input_form in user_input_form_list:
|
|
||||||
# get type
|
|
||||||
form_type = input_form.keys()[0]
|
|
||||||
default = input_form[form_type]['default']
|
|
||||||
required = input_form[form_type]['required']
|
|
||||||
label = input_form[form_type]['label']
|
|
||||||
variable_name = input_form[form_type]['variable_name']
|
|
||||||
options = input_form[form_type].get('options', [])
|
|
||||||
if form_type == 'paragraph' or form_type == 'text-input':
|
|
||||||
tool['parameters'].append(ToolParameter(
|
|
||||||
name=variable_name,
|
|
||||||
label=I18nObject(
|
|
||||||
en_US=label,
|
|
||||||
zh_Hans=label
|
|
||||||
),
|
|
||||||
human_description=I18nObject(
|
|
||||||
en_US=label,
|
|
||||||
zh_Hans=label
|
|
||||||
),
|
|
||||||
llm_description=label,
|
|
||||||
form=ToolParameter.ToolParameterForm.FORM,
|
|
||||||
type=ToolParameter.ToolParameterType.STRING,
|
|
||||||
required=required,
|
|
||||||
default=default
|
|
||||||
))
|
|
||||||
elif form_type == 'select':
|
|
||||||
tool['parameters'].append(ToolParameter(
|
|
||||||
name=variable_name,
|
|
||||||
label=I18nObject(
|
|
||||||
en_US=label,
|
|
||||||
zh_Hans=label
|
|
||||||
),
|
|
||||||
human_description=I18nObject(
|
|
||||||
en_US=label,
|
|
||||||
zh_Hans=label
|
|
||||||
),
|
|
||||||
llm_description=label,
|
|
||||||
form=ToolParameter.ToolParameterForm.FORM,
|
|
||||||
type=ToolParameter.ToolParameterType.SELECT,
|
|
||||||
required=required,
|
|
||||||
default=default,
|
|
||||||
options=[ToolParameterOption(
|
|
||||||
value=option,
|
|
||||||
label=I18nObject(
|
|
||||||
en_US=option,
|
|
||||||
zh_Hans=option
|
|
||||||
)
|
|
||||||
) for option in options]
|
|
||||||
))
|
|
||||||
|
|
||||||
tools.append(Tool(**tool))
|
|
||||||
return tools
|
|
||||||
@ -2,22 +2,23 @@ from abc import abstractmethod
|
|||||||
from os import listdir, path
|
from os import listdir, path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||||
from core.tools.errors import (
|
from core.tools.errors import (
|
||||||
ToolNotFoundError,
|
|
||||||
ToolParameterValidationError,
|
|
||||||
ToolProviderNotFoundError,
|
ToolProviderNotFoundError,
|
||||||
)
|
)
|
||||||
from core.tools.provider.tool_provider import ToolProviderController
|
from core.tools.provider.tool_provider import ToolProviderController
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
from core.tools.tool.tool import Tool
|
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
from core.tools.utils.yaml_utils import load_yaml_file
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
|
|
||||||
|
|
||||||
class BuiltinToolProviderController(ToolProviderController):
|
class BuiltinToolProviderController(ToolProviderController):
|
||||||
|
tools: list[BuiltinTool] = Field(default_factory=list)
|
||||||
|
|
||||||
def __init__(self, **data: Any) -> None:
|
def __init__(self, **data: Any) -> None:
|
||||||
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
|
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
@ -41,7 +42,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
'credentials_schema': provider_yaml.get('credentials_for_provider', None),
|
'credentials_schema': provider_yaml.get('credentials_for_provider', None),
|
||||||
})
|
})
|
||||||
|
|
||||||
def _get_builtin_tools(self) -> list[Tool]:
|
def _get_builtin_tools(self) -> list[BuiltinTool]:
|
||||||
"""
|
"""
|
||||||
returns a list of tools that the provider can provide
|
returns a list of tools that the provider can provide
|
||||||
|
|
||||||
@ -72,7 +73,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
self.tools = tools
|
self.tools = tools
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
|
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||||
"""
|
"""
|
||||||
returns the credentials schema of the provider
|
returns the credentials schema of the provider
|
||||||
|
|
||||||
@ -83,7 +84,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
|
|
||||||
return self.credentials_schema.copy()
|
return self.credentials_schema.copy()
|
||||||
|
|
||||||
def get_tools(self) -> list[Tool]:
|
def get_tools(self) -> list[BuiltinTool]:
|
||||||
"""
|
"""
|
||||||
returns a list of tools that the provider can provide
|
returns a list of tools that the provider can provide
|
||||||
|
|
||||||
@ -91,24 +92,12 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
"""
|
"""
|
||||||
return self._get_builtin_tools()
|
return self._get_builtin_tools()
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> Tool:
|
def get_tool(self, tool_name: str) -> BuiltinTool | None:
|
||||||
"""
|
"""
|
||||||
returns the tool that the provider can provide
|
returns the tool that the provider can provide
|
||||||
"""
|
"""
|
||||||
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||||
|
|
||||||
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
|
|
||||||
"""
|
|
||||||
returns the parameters of the tool
|
|
||||||
|
|
||||||
:param tool_name: the name of the tool, defined in `get_tools`
|
|
||||||
:return: list of parameters
|
|
||||||
"""
|
|
||||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
|
||||||
if tool is None:
|
|
||||||
raise ToolNotFoundError(f'tool {tool_name} not found')
|
|
||||||
return tool.parameters
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def need_credentials(self) -> bool:
|
def need_credentials(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -143,67 +132,6 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||||||
returns the labels of the provider
|
returns the labels of the provider
|
||||||
"""
|
"""
|
||||||
return self.identity.tags or []
|
return self.identity.tags or []
|
||||||
|
|
||||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
|
||||||
"""
|
|
||||||
validate the parameters of the tool and set the default value if needed
|
|
||||||
|
|
||||||
:param tool_name: the name of the tool, defined in `get_tools`
|
|
||||||
:param tool_parameters: the parameters of the tool
|
|
||||||
"""
|
|
||||||
tool_parameters_schema = self.get_parameters(tool_name)
|
|
||||||
|
|
||||||
tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
|
|
||||||
for parameter in tool_parameters_schema:
|
|
||||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
|
||||||
|
|
||||||
for parameter in tool_parameters:
|
|
||||||
if parameter not in tool_parameters_need_to_validate:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
|
|
||||||
|
|
||||||
# check type
|
|
||||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
|
||||||
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
|
|
||||||
if not isinstance(tool_parameters[parameter], str):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be string')
|
|
||||||
|
|
||||||
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
|
|
||||||
if not isinstance(tool_parameters[parameter], int | float):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be number')
|
|
||||||
|
|
||||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
|
|
||||||
|
|
||||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
|
|
||||||
|
|
||||||
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
|
|
||||||
if not isinstance(tool_parameters[parameter], bool):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
|
|
||||||
|
|
||||||
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
|
|
||||||
if not isinstance(tool_parameters[parameter], str):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be string')
|
|
||||||
|
|
||||||
options = parameter_schema.options
|
|
||||||
if not isinstance(options, list):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
|
|
||||||
|
|
||||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
|
|
||||||
|
|
||||||
tool_parameters_need_to_validate.pop(parameter)
|
|
||||||
|
|
||||||
for parameter in tool_parameters_need_to_validate:
|
|
||||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
|
||||||
if parameter_schema.required:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} is required')
|
|
||||||
|
|
||||||
# the parameter is not set currently, set the default value if needed
|
|
||||||
if parameter_schema.default is not None:
|
|
||||||
default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
|
|
||||||
parameter_schema.type)
|
|
||||||
tool_parameters[parameter] = default_value
|
|
||||||
|
|
||||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,25 +1,23 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ToolParameter,
|
|
||||||
ToolProviderCredentials,
|
|
||||||
ToolProviderIdentity,
|
ToolProviderIdentity,
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderController(BaseModel, ABC):
|
class ToolProviderController(BaseModel, ABC):
|
||||||
identity: Optional[ToolProviderIdentity] = None
|
identity: ToolProviderIdentity
|
||||||
tools: Optional[list[Tool]] = None
|
tools: list[Tool] = Field(default_factory=list)
|
||||||
credentials_schema: Optional[dict[str, ToolProviderCredentials]] = None
|
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||||
|
|
||||||
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
|
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||||
"""
|
"""
|
||||||
returns the credentials schema of the provider
|
returns the credentials schema of the provider
|
||||||
|
|
||||||
@ -27,15 +25,6 @@ class ToolProviderController(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
return self.credentials_schema.copy()
|
return self.credentials_schema.copy()
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_tools(self) -> list[Tool]:
|
|
||||||
"""
|
|
||||||
returns a list of tools that the provider can provide
|
|
||||||
|
|
||||||
:return: list of tools
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_tool(self, tool_name: str) -> Tool:
|
def get_tool(self, tool_name: str) -> Tool:
|
||||||
"""
|
"""
|
||||||
@ -45,18 +34,6 @@ class ToolProviderController(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
|
|
||||||
"""
|
|
||||||
returns the parameters of the tool
|
|
||||||
|
|
||||||
:param tool_name: the name of the tool, defined in `get_tools`
|
|
||||||
:return: list of parameters
|
|
||||||
"""
|
|
||||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
|
||||||
if tool is None:
|
|
||||||
raise ToolNotFoundError(f'tool {tool_name} not found')
|
|
||||||
return tool.parameters
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_type(self) -> ToolProviderType:
|
def provider_type(self) -> ToolProviderType:
|
||||||
"""
|
"""
|
||||||
@ -66,66 +43,6 @@ class ToolProviderController(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
return ToolProviderType.BUILT_IN
|
return ToolProviderType.BUILT_IN
|
||||||
|
|
||||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
|
||||||
"""
|
|
||||||
validate the parameters of the tool and set the default value if needed
|
|
||||||
|
|
||||||
:param tool_name: the name of the tool, defined in `get_tools`
|
|
||||||
:param tool_parameters: the parameters of the tool
|
|
||||||
"""
|
|
||||||
tool_parameters_schema = self.get_parameters(tool_name)
|
|
||||||
|
|
||||||
tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
|
|
||||||
for parameter in tool_parameters_schema:
|
|
||||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
|
||||||
|
|
||||||
for parameter in tool_parameters:
|
|
||||||
if parameter not in tool_parameters_need_to_validate:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
|
|
||||||
|
|
||||||
# check type
|
|
||||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
|
||||||
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
|
|
||||||
if not isinstance(tool_parameters[parameter], str):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be string')
|
|
||||||
|
|
||||||
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
|
|
||||||
if not isinstance(tool_parameters[parameter], int | float):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be number')
|
|
||||||
|
|
||||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
|
|
||||||
|
|
||||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
|
|
||||||
|
|
||||||
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
|
|
||||||
if not isinstance(tool_parameters[parameter], bool):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
|
|
||||||
|
|
||||||
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
|
|
||||||
if not isinstance(tool_parameters[parameter], str):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be string')
|
|
||||||
|
|
||||||
options = parameter_schema.options
|
|
||||||
if not isinstance(options, list):
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
|
|
||||||
|
|
||||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
|
|
||||||
|
|
||||||
tool_parameters_need_to_validate.pop(parameter)
|
|
||||||
|
|
||||||
for parameter in tool_parameters_need_to_validate:
|
|
||||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
|
||||||
if parameter_schema.required:
|
|
||||||
raise ToolParameterValidationError(f'parameter {parameter} is required')
|
|
||||||
|
|
||||||
# the parameter is not set currently, set the default value if needed
|
|
||||||
if parameter_schema.default is not None:
|
|
||||||
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
|
|
||||||
parameter_schema.type)
|
|
||||||
|
|
||||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
validate the format of the credentials of the provider and set the default value if needed
|
validate the format of the credentials of the provider and set the default value if needed
|
||||||
@ -136,7 +53,7 @@ class ToolProviderController(BaseModel, ABC):
|
|||||||
if credentials_schema is None:
|
if credentials_schema is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
credentials_need_to_validate: dict[str, ToolProviderCredentials] = {}
|
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||||
for credential_name in credentials_schema:
|
for credential_name in credentials_schema:
|
||||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||||
|
|
||||||
@ -146,12 +63,12 @@ class ToolProviderController(BaseModel, ABC):
|
|||||||
|
|
||||||
# check type
|
# check type
|
||||||
credential_schema = credentials_need_to_validate[credential_name]
|
credential_schema = credentials_need_to_validate[credential_name]
|
||||||
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
|
if credential_schema == ProviderConfig.Type.SECRET_INPUT or \
|
||||||
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
|
credential_schema == ProviderConfig.Type.TEXT_INPUT:
|
||||||
if not isinstance(credentials[credential_name], str):
|
if not isinstance(credentials[credential_name], str):
|
||||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
||||||
|
|
||||||
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
|
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||||
if not isinstance(credentials[credential_name], str):
|
if not isinstance(credentials[credential_name], str):
|
||||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
||||||
|
|
||||||
@ -173,9 +90,9 @@ class ToolProviderController(BaseModel, ABC):
|
|||||||
if credential_schema.default is not None:
|
if credential_schema.default is not None:
|
||||||
default_value = credential_schema.default
|
default_value = credential_schema.default
|
||||||
# parse default value into the correct type
|
# parse default value into the correct type
|
||||||
if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
|
if credential_schema.type == ProviderConfig.Type.SECRET_INPUT or \
|
||||||
credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
|
credential_schema.type == ProviderConfig.Type.TEXT_INPUT or \
|
||||||
credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
|
credential_schema.type == ProviderConfig.Type.SELECT:
|
||||||
default_value = str(default_value)
|
default_value = str(default_value)
|
||||||
|
|
||||||
credentials[credential_name] = default_value
|
credentials[credential_name] = default_value
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
@ -28,6 +31,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
|||||||
|
|
||||||
class WorkflowToolProviderController(ToolProviderController):
|
class WorkflowToolProviderController(ToolProviderController):
|
||||||
provider_id: str
|
provider_id: str
|
||||||
|
tools: list[WorkflowTool] = Field(default_factory=list)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
|
def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
|
||||||
@ -71,16 +75,17 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
:param app: the app
|
:param app: the app
|
||||||
:return: the tool
|
:return: the tool
|
||||||
"""
|
"""
|
||||||
workflow: Workflow = db.session.query(Workflow).filter(
|
workflow: Workflow | None = db.session.query(Workflow).filter(
|
||||||
Workflow.app_id == db_provider.app_id,
|
Workflow.app_id == db_provider.app_id,
|
||||||
Workflow.version == db_provider.version
|
Workflow.version == db_provider.version
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not workflow:
|
if not workflow:
|
||||||
raise ValueError('workflow not found')
|
raise ValueError('workflow not found')
|
||||||
|
|
||||||
# fetch start node
|
# fetch start node
|
||||||
graph: dict = workflow.graph_dict
|
graph: Mapping = workflow.graph_dict
|
||||||
features_dict: dict = workflow.features_dict
|
features_dict: Mapping = workflow.features_dict
|
||||||
features = WorkflowAppConfigManager.convert_features(
|
features = WorkflowAppConfigManager.convert_features(
|
||||||
config_dict=features_dict,
|
config_dict=features_dict,
|
||||||
app_mode=AppMode.WORKFLOW
|
app_mode=AppMode.WORKFLOW
|
||||||
@ -89,7 +94,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
parameters = db_provider.parameter_configurations
|
parameters = db_provider.parameter_configurations
|
||||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||||
|
|
||||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
|
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
||||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||||
|
|
||||||
user = db_provider.user
|
user = db_provider.user
|
||||||
@ -99,7 +104,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
variable = fetch_workflow_variable(parameter.name)
|
variable = fetch_workflow_variable(parameter.name)
|
||||||
if variable:
|
if variable:
|
||||||
parameter_type = None
|
parameter_type = None
|
||||||
options = None
|
options = []
|
||||||
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
||||||
raise ValueError(f'unsupported variable type {variable.type}')
|
raise ValueError(f'unsupported variable type {variable.type}')
|
||||||
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
||||||
@ -185,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
label=db_provider.label
|
label=db_provider.label
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
|
def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
|
||||||
"""
|
"""
|
||||||
fetch tools from database
|
fetch tools from database
|
||||||
|
|
||||||
@ -196,7 +201,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
if self.tools is not None:
|
if self.tools is not None:
|
||||||
return self.tools
|
return self.tools
|
||||||
|
|
||||||
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
|
||||||
WorkflowToolProvider.tenant_id == tenant_id,
|
WorkflowToolProvider.tenant_id == tenant_id,
|
||||||
WorkflowToolProvider.app_id == self.provider_id,
|
WorkflowToolProvider.app_id == self.provider_id,
|
||||||
).first()
|
).first()
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class Tool(BaseModel, ABC):
|
|||||||
invoke_from: Optional[InvokeFrom] = None
|
invoke_from: Optional[InvokeFrom] = None
|
||||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||||
credentials: Optional[dict[str, Any]] = None
|
credentials: Optional[dict[str, Any]] = None
|
||||||
runtime_parameters: Optional[dict[str, Any]] = None
|
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
runtime: Optional[Runtime] = None
|
runtime: Optional[Runtime] = None
|
||||||
variables: Optional[ToolRuntimeVariablePool] = None
|
variables: Optional[ToolRuntimeVariablePool] = None
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import mimetypes
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from os import listdir, path
|
from os import listdir, path
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any, Union
|
from typing import Any, Union, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
@ -22,6 +22,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
|||||||
from core.tools.tool.api_tool import ApiTool
|
from core.tools.tool.api_tool import ApiTool
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
|
from core.tools.tool.workflow_tool import WorkflowTool
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||||
@ -57,7 +58,7 @@ class ToolManager:
|
|||||||
return cls._builtin_providers[provider]
|
return cls._builtin_providers[provider]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
|
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
|
||||||
"""
|
"""
|
||||||
get the builtin tool
|
get the builtin tool
|
||||||
|
|
||||||
@ -78,7 +79,7 @@ class ToolManager:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||||
-> Union[BuiltinTool, ApiTool]:
|
-> Union[BuiltinTool, ApiTool, WorkflowTool]:
|
||||||
"""
|
"""
|
||||||
get the tool runtime
|
get the tool runtime
|
||||||
|
|
||||||
@ -90,19 +91,21 @@ class ToolManager:
|
|||||||
"""
|
"""
|
||||||
if provider_type == ToolProviderType.BUILT_IN:
|
if provider_type == ToolProviderType.BUILT_IN:
|
||||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
||||||
|
if not builtin_tool:
|
||||||
|
raise ValueError(f"tool {tool_name} not found")
|
||||||
|
|
||||||
# check if the builtin tool need credentials
|
# check if the builtin tool need credentials
|
||||||
provider_controller = cls.get_builtin_provider(provider_id)
|
provider_controller = cls.get_builtin_provider(provider_id)
|
||||||
if not provider_controller.need_credentials:
|
if not provider_controller.need_credentials:
|
||||||
return builtin_tool.fork_tool_runtime(runtime={
|
return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
|
||||||
'tenant_id': tenant_id,
|
'tenant_id': tenant_id,
|
||||||
'credentials': {},
|
'credentials': {},
|
||||||
'invoke_from': invoke_from,
|
'invoke_from': invoke_from,
|
||||||
'tool_invoke_from': tool_invoke_from,
|
'tool_invoke_from': tool_invoke_from,
|
||||||
})
|
}))
|
||||||
|
|
||||||
# get credentials
|
# get credentials
|
||||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
builtin_provider: BuiltinToolProvider | None = db.session.query(BuiltinToolProvider).filter(
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
BuiltinToolProvider.provider == provider_id,
|
BuiltinToolProvider.provider == provider_id,
|
||||||
).first()
|
).first()
|
||||||
@ -117,13 +120,13 @@ class ToolManager:
|
|||||||
|
|
||||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||||
|
|
||||||
return builtin_tool.fork_tool_runtime(runtime={
|
return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
|
||||||
'tenant_id': tenant_id,
|
'tenant_id': tenant_id,
|
||||||
'credentials': decrypted_credentials,
|
'credentials': decrypted_credentials,
|
||||||
'runtime_parameters': {},
|
'runtime_parameters': {},
|
||||||
'invoke_from': invoke_from,
|
'invoke_from': invoke_from,
|
||||||
'tool_invoke_from': tool_invoke_from,
|
'tool_invoke_from': tool_invoke_from,
|
||||||
})
|
}))
|
||||||
|
|
||||||
elif provider_type == ToolProviderType.API:
|
elif provider_type == ToolProviderType.API:
|
||||||
if tenant_id is None:
|
if tenant_id is None:
|
||||||
@ -135,12 +138,12 @@ class ToolManager:
|
|||||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
||||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||||
|
|
||||||
return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
|
return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
|
||||||
'tenant_id': tenant_id,
|
'tenant_id': tenant_id,
|
||||||
'credentials': decrypted_credentials,
|
'credentials': decrypted_credentials,
|
||||||
'invoke_from': invoke_from,
|
'invoke_from': invoke_from,
|
||||||
'tool_invoke_from': tool_invoke_from,
|
'tool_invoke_from': tool_invoke_from,
|
||||||
})
|
}))
|
||||||
elif provider_type == ToolProviderType.WORKFLOW:
|
elif provider_type == ToolProviderType.WORKFLOW:
|
||||||
workflow_provider = db.session.query(WorkflowToolProvider).filter(
|
workflow_provider = db.session.query(WorkflowToolProvider).filter(
|
||||||
WorkflowToolProvider.tenant_id == tenant_id,
|
WorkflowToolProvider.tenant_id == tenant_id,
|
||||||
@ -154,12 +157,12 @@ class ToolManager:
|
|||||||
db_provider=workflow_provider
|
db_provider=workflow_provider
|
||||||
)
|
)
|
||||||
|
|
||||||
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
|
return cast(WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
|
||||||
'tenant_id': tenant_id,
|
'tenant_id': tenant_id,
|
||||||
'credentials': {},
|
'credentials': {},
|
||||||
'invoke_from': invoke_from,
|
'invoke_from': invoke_from,
|
||||||
'tool_invoke_from': tool_invoke_from,
|
'tool_invoke_from': tool_invoke_from,
|
||||||
})
|
}))
|
||||||
elif provider_type == ToolProviderType.APP:
|
elif provider_type == ToolProviderType.APP:
|
||||||
raise NotImplementedError('app provider not implemented')
|
raise NotImplementedError('app provider not implemented')
|
||||||
else:
|
else:
|
||||||
@ -220,7 +223,10 @@ class ToolManager:
|
|||||||
identity_id=f'AGENT.{app_id}'
|
identity_id=f'AGENT.{app_id}'
|
||||||
)
|
)
|
||||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||||
|
|
||||||
|
if not tool_entity.runtime:
|
||||||
|
raise Exception("tool missing runtime")
|
||||||
|
|
||||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||||
return tool_entity
|
return tool_entity
|
||||||
|
|
||||||
@ -258,6 +264,9 @@ class ToolManager:
|
|||||||
if runtime_parameters:
|
if runtime_parameters:
|
||||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||||
|
|
||||||
|
if not tool_entity.runtime:
|
||||||
|
raise Exception("tool missing runtime")
|
||||||
|
|
||||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||||
return tool_entity
|
return tool_entity
|
||||||
|
|
||||||
@ -304,20 +313,20 @@ class ToolManager:
|
|||||||
"""
|
"""
|
||||||
list all the builtin providers
|
list all the builtin providers
|
||||||
"""
|
"""
|
||||||
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||||
if provider.startswith('__'):
|
if provider_path.startswith('__'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
|
if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider_path)):
|
||||||
if provider.startswith('__'):
|
if provider_path.startswith('__'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# init provider
|
# init provider
|
||||||
try:
|
try:
|
||||||
provider_class = load_single_subclass_from_source(
|
provider_class = load_single_subclass_from_source(
|
||||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
module_name=f'core.tools.provider.builtin.{provider_path}.{provider_path}',
|
||||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||||
'provider', 'builtin', provider, f'{provider}.py'),
|
'provider', 'builtin', provider_path, f'{provider_path}.py'),
|
||||||
parent_type=BuiltinToolProviderController)
|
parent_type=BuiltinToolProviderController)
|
||||||
provider: BuiltinToolProviderController = provider_class()
|
provider: BuiltinToolProviderController = provider_class()
|
||||||
cls._builtin_providers[provider.identity.name] = provider
|
cls._builtin_providers[provider.identity.name] = provider
|
||||||
@ -387,8 +396,8 @@ class ToolManager:
|
|||||||
for provider in builtin_providers:
|
for provider in builtin_providers:
|
||||||
# handle include, exclude
|
# handle include, exclude
|
||||||
if is_filtered(
|
if is_filtered(
|
||||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||||
data=provider,
|
data=provider,
|
||||||
name_func=lambda x: x.identity.name
|
name_func=lambda x: x.identity.name
|
||||||
):
|
):
|
||||||
@ -461,7 +470,7 @@ class ToolManager:
|
|||||||
|
|
||||||
:return: the provider controller, the credentials
|
:return: the provider controller, the credentials
|
||||||
"""
|
"""
|
||||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
|
||||||
ApiToolProvider.id == provider_id,
|
ApiToolProvider.id == provider_id,
|
||||||
ApiToolProvider.tenant_id == tenant_id,
|
ApiToolProvider.tenant_id == tenant_id,
|
||||||
).first()
|
).first()
|
||||||
@ -486,22 +495,22 @@ class ToolManager:
|
|||||||
"""
|
"""
|
||||||
get tool provider
|
get tool provider
|
||||||
"""
|
"""
|
||||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
provider_obj: ApiToolProvider| None = db.session.query(ApiToolProvider).filter(
|
||||||
ApiToolProvider.tenant_id == tenant_id,
|
ApiToolProvider.tenant_id == tenant_id,
|
||||||
ApiToolProvider.name == provider,
|
ApiToolProvider.name == provider,
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if provider is None:
|
if provider_obj is None:
|
||||||
raise ValueError(f'you have not added provider {provider}')
|
raise ValueError(f'you have not added provider {provider}')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credentials = json.loads(provider.credentials_str) or {}
|
credentials = json.loads(provider_obj.credentials_str) or {}
|
||||||
except:
|
except:
|
||||||
credentials = {}
|
credentials = {}
|
||||||
|
|
||||||
# package tool provider controller
|
# package tool provider controller
|
||||||
controller = ApiToolProviderController.from_db(
|
controller = ApiToolProviderController.from_db(
|
||||||
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||||
)
|
)
|
||||||
# init tool configuration
|
# init tool configuration
|
||||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||||
@ -510,7 +519,7 @@ class ToolManager:
|
|||||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
icon = json.loads(provider.icon)
|
icon = json.loads(provider_obj.icon)
|
||||||
except:
|
except:
|
||||||
icon = {
|
icon = {
|
||||||
"background": "#252525",
|
"background": "#252525",
|
||||||
@ -521,14 +530,14 @@ class ToolManager:
|
|||||||
labels = ToolLabelManager.get_tool_labels(controller)
|
labels = ToolLabelManager.get_tool_labels(controller)
|
||||||
|
|
||||||
return jsonable_encoder({
|
return jsonable_encoder({
|
||||||
'schema_type': provider.schema_type,
|
'schema_type': provider_obj.schema_type,
|
||||||
'schema': provider.schema,
|
'schema': provider_obj.schema,
|
||||||
'tools': provider.tools,
|
'tools': provider_obj.tools,
|
||||||
'icon': icon,
|
'icon': icon,
|
||||||
'description': provider.description,
|
'description': provider_obj.description,
|
||||||
'credentials': masked_credentials,
|
'credentials': masked_credentials,
|
||||||
'privacy_policy': provider.privacy_policy,
|
'privacy_policy': provider_obj.privacy_policy,
|
||||||
'custom_disclaimer': provider.custom_disclaimer,
|
'custom_disclaimer': provider_obj.custom_disclaimer,
|
||||||
'labels': labels,
|
'labels': labels,
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -551,25 +560,29 @@ class ToolManager:
|
|||||||
+ "/icon")
|
+ "/icon")
|
||||||
elif provider_type == ToolProviderType.API:
|
elif provider_type == ToolProviderType.API:
|
||||||
try:
|
try:
|
||||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
api_provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
|
||||||
ApiToolProvider.tenant_id == tenant_id,
|
ApiToolProvider.tenant_id == tenant_id,
|
||||||
ApiToolProvider.id == provider_id
|
ApiToolProvider.id == provider_id
|
||||||
).first()
|
).first()
|
||||||
return json.loads(provider.icon)
|
if not api_provider:
|
||||||
|
raise ValueError("api tool not found")
|
||||||
|
|
||||||
|
return json.loads(api_provider.icon)
|
||||||
except:
|
except:
|
||||||
return {
|
return {
|
||||||
"background": "#252525",
|
"background": "#252525",
|
||||||
"content": "\ud83d\ude01"
|
"content": "\ud83d\ude01"
|
||||||
}
|
}
|
||||||
elif provider_type == ToolProviderType.WORKFLOW:
|
elif provider_type == ToolProviderType.WORKFLOW:
|
||||||
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
workflow_provider: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
|
||||||
WorkflowToolProvider.tenant_id == tenant_id,
|
WorkflowToolProvider.tenant_id == tenant_id,
|
||||||
WorkflowToolProvider.id == provider_id
|
WorkflowToolProvider.id == provider_id
|
||||||
).first()
|
).first()
|
||||||
if provider is None:
|
|
||||||
|
if workflow_provider is None:
|
||||||
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
|
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
|
||||||
|
|
||||||
return json.loads(provider.icon)
|
return json.loads(workflow_provider.icon)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"provider type {provider_type} not found")
|
raise ValueError(f"provider type {provider_type} not found")
|
||||||
|
|
||||||
|
|||||||
@ -7,8 +7,8 @@ from core.helper import encrypter
|
|||||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
|
ProviderConfig,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolProviderCredentials,
|
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.provider.tool_provider import ToolProviderController
|
from core.tools.provider.tool_provider import ToolProviderController
|
||||||
@ -36,7 +36,7 @@ class ToolConfigurationManager(BaseModel):
|
|||||||
# get fields need to be decrypted
|
# get fields need to be decrypted
|
||||||
fields = self.provider_controller.get_credentials_schema()
|
fields = self.provider_controller.get_credentials_schema()
|
||||||
for field_name, field in fields.items():
|
for field_name, field in fields.items():
|
||||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||||
if field_name in credentials:
|
if field_name in credentials:
|
||||||
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
|
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
|
||||||
credentials[field_name] = encrypted
|
credentials[field_name] = encrypted
|
||||||
@ -54,7 +54,7 @@ class ToolConfigurationManager(BaseModel):
|
|||||||
# get fields need to be decrypted
|
# get fields need to be decrypted
|
||||||
fields = self.provider_controller.get_credentials_schema()
|
fields = self.provider_controller.get_credentials_schema()
|
||||||
for field_name, field in fields.items():
|
for field_name, field in fields.items():
|
||||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||||
if field_name in credentials:
|
if field_name in credentials:
|
||||||
if len(credentials[field_name]) > 6:
|
if len(credentials[field_name]) > 6:
|
||||||
credentials[field_name] = \
|
credentials[field_name] = \
|
||||||
@ -84,7 +84,7 @@ class ToolConfigurationManager(BaseModel):
|
|||||||
# get fields need to be decrypted
|
# get fields need to be decrypted
|
||||||
fields = self.provider_controller.get_credentials_schema()
|
fields = self.provider_controller.get_credentials_schema()
|
||||||
for field_name, field in fields.items():
|
for field_name, field in fields.items():
|
||||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||||
if field_name in credentials:
|
if field_name in credentials:
|
||||||
try:
|
try:
|
||||||
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
|
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity
|
from core.app.app_config.entities import VariableEntity
|
||||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||||
|
|
||||||
@ -13,7 +15,7 @@ class WorkflowToolConfigurationUtils:
|
|||||||
raise ValueError('invalid parameter configuration')
|
raise ValueError('invalid parameter configuration')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
def get_workflow_graph_variables(cls, graph: Mapping) -> list[VariableEntity]:
|
||||||
"""
|
"""
|
||||||
get workflow graph variables
|
get workflow graph variables
|
||||||
"""
|
"""
|
||||||
@ -44,5 +46,3 @@ class WorkflowToolConfigurationUtils:
|
|||||||
for parameter in tool_configurations:
|
for parameter in tool_configurations:
|
||||||
if parameter.name not in variable_names:
|
if parameter.name not in variable_names:
|
||||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||||
|
|
||||||
return True
|
|
||||||
@ -10,8 +10,8 @@ from core.tools.entities.tool_bundle import ApiToolBundle
|
|||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
ApiProviderSchemaType,
|
ApiProviderSchemaType,
|
||||||
|
ProviderConfig,
|
||||||
ToolCredentialsOption,
|
ToolCredentialsOption,
|
||||||
ToolProviderCredentials,
|
|
||||||
)
|
)
|
||||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
@ -39,9 +39,9 @@ class ApiToolManageService:
|
|||||||
raise ValueError(f"invalid schema: {str(e)}")
|
raise ValueError(f"invalid schema: {str(e)}")
|
||||||
|
|
||||||
credentials_schema = [
|
credentials_schema = [
|
||||||
ToolProviderCredentials(
|
ProviderConfig(
|
||||||
name="auth_type",
|
name="auth_type",
|
||||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
type=ProviderConfig.Type.SELECT,
|
||||||
required=True,
|
required=True,
|
||||||
default="none",
|
default="none",
|
||||||
options=[
|
options=[
|
||||||
@ -50,17 +50,17 @@ class ApiToolManageService:
|
|||||||
],
|
],
|
||||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||||
),
|
),
|
||||||
ToolProviderCredentials(
|
ProviderConfig(
|
||||||
name="api_key_header",
|
name="api_key_header",
|
||||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
type=ProviderConfig.Type.TEXT_INPUT,
|
||||||
required=False,
|
required=False,
|
||||||
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
||||||
default="api_key",
|
default="api_key",
|
||||||
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
||||||
),
|
),
|
||||||
ToolProviderCredentials(
|
ProviderConfig(
|
||||||
name="api_key_value",
|
name="api_key_value",
|
||||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
type=ProviderConfig.Type.TEXT_INPUT,
|
||||||
required=False,
|
required=False,
|
||||||
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
||||||
default="",
|
default="",
|
||||||
|
|||||||
@ -8,8 +8,8 @@ from core.tools.entities.common_entities import I18nObject
|
|||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
|
ProviderConfig,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolProviderCredentials,
|
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||||
@ -92,7 +92,7 @@ class ToolTransformService:
|
|||||||
# get credentials schema
|
# get credentials schema
|
||||||
schema = provider_controller.get_credentials_schema()
|
schema = provider_controller.get_credentials_schema()
|
||||||
for name, value in schema.items():
|
for name, value in schema.items():
|
||||||
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
|
result.masked_credentials[name] = ProviderConfig.Type.default(value.type)
|
||||||
|
|
||||||
# check if the provider need credentials
|
# check if the provider need credentials
|
||||||
if not provider_controller.need_credentials:
|
if not provider_controller.need_credentials:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user