refactor: tool
|
|
@ -32,13 +32,13 @@ from core.model_runtime.entities.message_entities import (
|
|||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
|
|||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
|
|
|
|||
|
|
@ -24,9 +24,7 @@ from core.rag.models.document import Document
|
|||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
|
@ -371,7 +369,7 @@ class DatasetRetrieval:
|
|||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: TraceQueueManager = (
|
||||
trace_manager: TraceQueueManager | None = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
|
|
@ -494,7 +492,8 @@ class DatasetRetrieval:
|
|||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
top_k=top_k,
|
||||
|
|
@ -506,6 +505,7 @@ class DatasetRetrieval:
|
|||
|
||||
tools.append(tool)
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ from typing import Any
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
class ToolProviderController(BaseModel, ABC):
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
- code
|
||||
- time
|
||||
- qrcode
|
||||
|
|
@ -6,13 +6,13 @@ from pydantic import Field
|
|||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
|
||||
# load provider yaml
|
||||
provider = self.__class__.__module__.split(".")[-1]
|
||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, f"{provider}.yaml")
|
||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
|
||||
try:
|
||||
provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
|
||||
except Exception as e:
|
||||
|
|
@ -52,7 +52,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
return self.tools
|
||||
|
||||
provider = self.identity.name
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
|
||||
# get all the yaml files in the tool path
|
||||
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
||||
tools = []
|
||||
|
|
@ -63,9 +63,10 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
|
||||
# get tool class, import the module
|
||||
assistant_tool_class = load_single_subclass_from_source(
|
||||
module_name=f"core.tools.provider.builtin.{provider}.tools.{tool_name}",
|
||||
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
|
||||
script_path=path.join(
|
||||
path.dirname(path.realpath(__file__)), "builtin", provider, "tools", f"{tool_name}.py"
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool", "providers", provider, "tools", f"{tool_name}.py"
|
||||
),
|
||||
parent_type=BuiltinTool,
|
||||
)
|
||||
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 1.4 KiB |
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
|
Before Width: | Height: | Size: 428 B After Width: | Height: | Size: 428 B |
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class QRCodeProvider(BuiltinToolProviderController):
|
||||
|
|
@ -7,8 +7,8 @@ from qrcode.image.base import BaseImage
|
|||
from qrcode.image.pure import PyPNGImage
|
||||
from qrcode.main import QRCode
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class QRCodeGeneratorTool(BuiltinTool):
|
||||
|
Before Width: | Height: | Size: 691 B After Width: | Height: | Size: 691 B |
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
|
|
@ -3,8 +3,8 @@ from typing import Any, Union
|
|||
|
||||
from pytz import timezone as pytz_timezone
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CurrentTimeTool(BuiltinTool):
|
||||
|
|
@ -2,8 +2,8 @@ import calendar
|
|||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class WeekdayTool(BuiltinTool):
|
||||
|
|
@ -17,6 +17,8 @@ class WeekdayTool(BuiltinTool):
|
|||
"""
|
||||
year = tool_parameters.get("year")
|
||||
month = tool_parameters.get("month")
|
||||
if month is None:
|
||||
raise ValueError("Month is required")
|
||||
day = tool_parameters.get("day")
|
||||
|
||||
date_obj = self.convert_datetime(year, month, day)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ class BuiltinTool(Tool):
|
|||
# invoke model
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_name=self.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
|
|
@ -124,7 +124,7 @@ class BuiltinTool(Tool):
|
|||
|
||||
return result
|
||||
|
||||
def get_url(self, url: str, user_agent: str = None) -> str:
|
||||
def get_url(self, url: str, user_agent: str | None = None) -> str:
|
||||
"""
|
||||
get url
|
||||
"""
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
from pydantic import Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
|
@ -67,7 +67,8 @@ class ApiToolProviderController(ToolProviderController):
|
|||
else:
|
||||
raise ValueError(f"invalid auth type {auth_type}")
|
||||
|
||||
user_name = db_provider.user.name if db_provider.user_id else ""
|
||||
user = db_provider.user
|
||||
user_name = user.name if user else ""
|
||||
|
||||
return ApiToolProviderController(
|
||||
**{
|
||||
|
|
@ -7,10 +7,10 @@ from urllib.parse import urlencode
|
|||
import httpx
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
API_TOOL_DEFAULT_TIMEOUT = (
|
||||
int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
|
||||
|
|
@ -33,10 +33,10 @@ class ApiTool(Tool):
|
|||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
identity=self.identity.model_copy(),
|
||||
parameters=self.parameters.copy() if self.parameters else [],
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
|
||||
api_bundle=self.api_bundle.model_copy(),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
|
|
@ -60,6 +60,9 @@ class ApiTool(Tool):
|
|||
return ToolProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
if self.runtime == None:
|
||||
raise ToolProviderCredentialValidationError("runtime not initialized")
|
||||
|
||||
headers = {}
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
||||
|
|
@ -88,7 +91,7 @@ class ApiTool(Tool):
|
|||
|
||||
headers[api_key_header] = credentials["api_key_value"]
|
||||
|
||||
needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
|
||||
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
|
||||
for parameter in needed_parameters:
|
||||
if parameter.required and parameter.name not in parameters:
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
|
|
@ -204,7 +207,7 @@ class ApiTool(Tool):
|
|||
)
|
||||
return response
|
||||
else:
|
||||
raise ValueError(f"Invalid http method {self.method}")
|
||||
raise ValueError(f"Invalid http method {method}")
|
||||
|
||||
def _convert_body_property_any_of(
|
||||
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
|
||||
|
|
@ -4,9 +4,9 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool.tool import ToolParameter
|
||||
|
||||
|
||||
class UserTool(BaseModel):
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class ToolProviderType(str, Enum):
|
|||
Enum class for tool provider
|
||||
"""
|
||||
|
||||
PLUGIN = "plugin"
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class PluginToolProvider(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
return super().get_tool(tool_name)
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
get credentials schema
|
||||
"""
|
||||
return super().get_credentials_schema()
|
||||
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
- google
|
||||
- bing
|
||||
- perplexity
|
||||
- duckduckgo
|
||||
- searchapi
|
||||
- serper
|
||||
- searxng
|
||||
- dalle
|
||||
- azuredalle
|
||||
- stability
|
||||
- wikipedia
|
||||
- nominatim
|
||||
- yahoo
|
||||
- alphavantage
|
||||
- arxiv
|
||||
- pubmed
|
||||
- stablediffusion
|
||||
- webscraper
|
||||
- jina
|
||||
- aippt
|
||||
- youtube
|
||||
- code
|
||||
- wolframalpha
|
||||
- maths
|
||||
- github
|
||||
- chart
|
||||
- time
|
||||
- vectorizer
|
||||
- gaode
|
||||
- wecom
|
||||
- qrcode
|
||||
- dingtalk
|
||||
- feishu
|
||||
- feishu_base
|
||||
- feishu_document
|
||||
- feishu_message
|
||||
- slack
|
||||
- tianditu
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
from models.tools import PublishedAppTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str) -> list[Tool]:
|
||||
db_tools: list[PublishedAppTool] = (
|
||||
db.session.query(PublishedAppTool)
|
||||
.filter(
|
||||
PublishedAppTool.user_id == user_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not db_tools or len(db_tools) == 0:
|
||||
return []
|
||||
|
||||
tools: list[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool = {
|
||||
"identity": {
|
||||
"author": db_tool.author,
|
||||
"name": db_tool.tool_name,
|
||||
"label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name},
|
||||
"icon": "",
|
||||
},
|
||||
"description": {
|
||||
"human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans},
|
||||
"llm": db_tool.llm_description,
|
||||
},
|
||||
"parameters": [],
|
||||
}
|
||||
# get app from db
|
||||
app: App = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
continue
|
||||
|
||||
app_model_config: AppModelConfig = app.app_model_config
|
||||
user_input_form_list = app_model_config.user_input_form_list
|
||||
for input_form in user_input_form_list:
|
||||
# get type
|
||||
form_type = input_form.keys()[0]
|
||||
default = input_form[form_type]["default"]
|
||||
required = input_form[form_type]["required"]
|
||||
label = input_form[form_type]["label"]
|
||||
variable_name = input_form[form_type]["variable_name"]
|
||||
options = input_form[form_type].get("options", [])
|
||||
if form_type in {"paragraph", "text-input"}:
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(en_US=label, zh_Hans=label),
|
||||
human_description=I18nObject(en_US=label, zh_Hans=label),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default,
|
||||
)
|
||||
)
|
||||
elif form_type == "select":
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(en_US=label, zh_Hans=label),
|
||||
human_description=I18nObject(en_US=label, zh_Hans=label),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
options=[
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
tools.append(Tool(**tool))
|
||||
return tools
|
||||
|
Before Width: | Height: | Size: 1.9 KiB |
|
|
@ -1,11 +0,0 @@
|
|||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AIPPTProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__")
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
identity:
|
||||
author: Dify
|
||||
name: aippt
|
||||
label:
|
||||
en_US: AIPPT
|
||||
zh_Hans: AIPPT
|
||||
description:
|
||||
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
|
||||
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
|
||||
icon: icon.png
|
||||
tags:
|
||||
- productivity
|
||||
- design
|
||||
credentials_for_provider:
|
||||
aippt_access_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AIPPT API key
|
||||
zh_Hans: AIPPT API key
|
||||
pt_BR: AIPPT API key
|
||||
help:
|
||||
en_US: Please input your AIPPT API key
|
||||
zh_Hans: 请输入你的 AIPPT API key
|
||||
pt_BR: Please input your AIPPT API key
|
||||
placeholder:
|
||||
en_US: Please input your AIPPT API key
|
||||
zh_Hans: 请输入你的 AIPPT API key
|
||||
pt_BR: Please input your AIPPT API key
|
||||
url: https://www.aippt.cn
|
||||
aippt_secret_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AIPPT Secret key
|
||||
zh_Hans: AIPPT Secret key
|
||||
pt_BR: AIPPT Secret key
|
||||
help:
|
||||
en_US: Please input your AIPPT Secret key
|
||||
zh_Hans: 请输入你的 AIPPT Secret key
|
||||
pt_BR: Please input your AIPPT Secret key
|
||||
placeholder:
|
||||
en_US: Please input your AIPPT Secret key
|
||||
zh_Hans: 请输入你的 AIPPT Secret key
|
||||
pt_BR: Please input your AIPPT Secret key
|
||||
|
|
@ -1,498 +0,0 @@
|
|||
from base64 import b64encode
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from json import loads as json_loads
|
||||
from threading import Lock
|
||||
from time import sleep, time
|
||||
from typing import Any, Optional
|
||||
|
||||
from httpx import get, post
|
||||
from requests import get as requests_get
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AIPPTGenerateTool(BuiltinTool):
|
||||
"""
|
||||
A tool for generating a ppt
|
||||
"""
|
||||
|
||||
_api_base_url = URL("https://co.aippt.cn/api")
|
||||
_api_token_cache = {}
|
||||
_api_token_cache_lock: Optional[Lock] = None
|
||||
_style_cache = {}
|
||||
_style_cache_lock: Optional[Lock] = None
|
||||
|
||||
_task = {}
|
||||
_task_type_map = {
|
||||
"auto": 1,
|
||||
"markdown": 7,
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._api_token_cache_lock = Lock()
|
||||
self._style_cache_lock = Lock()
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes the AIPPT generate tool with the given user ID and tool parameters.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
title = tool_parameters.get("title", "")
|
||||
if not title:
|
||||
return self.create_text_message("Please provide a title for the ppt")
|
||||
|
||||
model = tool_parameters.get("model", "aippt")
|
||||
if not model:
|
||||
return self.create_text_message("Please provide a model for the ppt")
|
||||
|
||||
outline = tool_parameters.get("outline", "")
|
||||
|
||||
# create task
|
||||
task_id = self._create_task(
|
||||
type=self._task_type_map["auto" if not outline else "markdown"],
|
||||
title=title,
|
||||
content=outline,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# get suit
|
||||
color = tool_parameters.get("color")
|
||||
style = tool_parameters.get("style")
|
||||
|
||||
if color == "__default__":
|
||||
color_id = ""
|
||||
else:
|
||||
color_id = int(color.split("-")[1])
|
||||
|
||||
if style == "__default__":
|
||||
style_id = ""
|
||||
else:
|
||||
style_id = int(style.split("-")[1])
|
||||
|
||||
suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
|
||||
|
||||
# generate outline
|
||||
if not outline:
|
||||
self._generate_outline(task_id=task_id, model=model, user_id=user_id)
|
||||
|
||||
# generate content
|
||||
self._generate_content(task_id=task_id, model=model, user_id=user_id)
|
||||
|
||||
# generate ppt
|
||||
_, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)
|
||||
|
||||
return self.create_text_message(
|
||||
"""the ppt has been created successfully,"""
|
||||
f"""the ppt url is {ppt_url}"""
|
||||
"""please give the ppt url to user and direct user to download it."""
|
||||
)
|
||||
|
||||
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
|
||||
"""
|
||||
Create a task
|
||||
|
||||
:param type: the task type
|
||||
:param title: the task title
|
||||
:param content: the task content
|
||||
|
||||
:return: the task ID
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
response = post(
|
||||
str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
|
||||
headers=headers,
|
||||
files={"type": ("", str(type)), "title": ("", title), "content": ("", content)},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to create task: {response.get("msg")}')
|
||||
|
||||
return response.get("data", {}).get("id")
|
||||
|
||||
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
|
||||
api_url = (
|
||||
self._api_base_url / "ai" / "chat" / "outline"
|
||||
if model == "aippt"
|
||||
else self._api_base_url / "ai" / "chat" / "wx" / "outline"
|
||||
)
|
||||
api_url %= {"task_id": task_id}
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
outline = ""
|
||||
for chunk in response.iter_lines(delimiter=b"\n\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
event = ""
|
||||
lines = chunk.decode("utf-8").split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event = line[6:]
|
||||
elif line.startswith("data:"):
|
||||
data = line[5:]
|
||||
if event == "message":
|
||||
try:
|
||||
data = json_loads(data)
|
||||
outline += data.get("content", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
elif event == "close":
|
||||
break
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate outline: {data}")
|
||||
|
||||
return outline
|
||||
|
||||
def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
|
||||
api_url = (
|
||||
self._api_base_url / "ai" / "chat" / "content"
|
||||
if model == "aippt"
|
||||
else self._api_base_url / "ai" / "chat" / "wx" / "content"
|
||||
)
|
||||
api_url %= {"task_id": task_id}
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
if model == "aippt":
|
||||
content = ""
|
||||
for chunk in response.iter_lines(delimiter=b"\n\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
event = ""
|
||||
lines = chunk.decode("utf-8").split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event = line[6:]
|
||||
elif line.startswith("data:"):
|
||||
data = line[5:]
|
||||
if event == "message":
|
||||
try:
|
||||
data = json_loads(data)
|
||||
content += data.get("content", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
elif event == "close":
|
||||
break
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate content: {data}")
|
||||
|
||||
return content
|
||||
elif model == "wenxin":
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate content: {response.get("msg")}')
|
||||
|
||||
return response.get("data", "")
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a ppt
|
||||
|
||||
:param task_id: the task ID
|
||||
:param suit_id: the suit ID
|
||||
:return: the cover url of the ppt and the ppt url
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / "design" / "v2" / "save"),
|
||||
headers=headers,
|
||||
data={"task_id": task_id, "template_id": suit_id},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
id = response.get("data", {}).get("id")
|
||||
cover_url = response.get("data", {}).get("cover_url")
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / "download" / "export" / "file"),
|
||||
headers=headers,
|
||||
data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
export_code = response.get("data")
|
||||
if not export_code:
|
||||
raise Exception("Failed to generate ppt, the export code is empty")
|
||||
|
||||
current_iteration = 0
|
||||
while current_iteration < 50:
|
||||
# get ppt url
|
||||
response = post(
|
||||
str(self._api_base_url / "download" / "export" / "file" / "result"),
|
||||
headers=headers,
|
||||
data={"task_key": export_code},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
if response.get("msg") == "导出中":
|
||||
current_iteration += 1
|
||||
sleep(2)
|
||||
continue
|
||||
|
||||
ppt_url = response.get("data", [])
|
||||
if len(ppt_url) == 0:
|
||||
raise Exception("Failed to generate ppt, the ppt url is empty")
|
||||
|
||||
return cover_url, ppt_url[0]
|
||||
|
||||
raise Exception("Failed to generate ppt, the export is timeout")
|
||||
|
||||
@classmethod
|
||||
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
|
||||
"""
|
||||
Get API token
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: the API token
|
||||
"""
|
||||
access_key = credentials["aippt_access_key"]
|
||||
secret_key = credentials["aippt_secret_key"]
|
||||
|
||||
cache_key = f"{access_key}#@#{user_id}"
|
||||
|
||||
with cls._api_token_cache_lock:
|
||||
# clear expired tokens
|
||||
now = time()
|
||||
for key in list(cls._api_token_cache.keys()):
|
||||
if cls._api_token_cache[key]["expire"] < now:
|
||||
del cls._api_token_cache[key]
|
||||
|
||||
if cache_key in cls._api_token_cache:
|
||||
return cls._api_token_cache[cache_key]["token"]
|
||||
|
||||
# get token
|
||||
headers = {
|
||||
"x-api-key": access_key,
|
||||
"x-timestamp": str(int(now)),
|
||||
"x-signature": cls._calculate_sign(access_key, secret_key, int(now)),
|
||||
}
|
||||
|
||||
param = {"uid": user_id, "channel": ""}
|
||||
|
||||
response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
token = response.get("data", {}).get("token")
|
||||
expire = response.get("data", {}).get("time_expire")
|
||||
|
||||
with cls._api_token_cache_lock:
|
||||
cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire}
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
|
||||
return b64encode(
|
||||
hmac_new(
|
||||
key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
|
||||
).digest()
|
||||
).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
"""
|
||||
|
||||
# check cache
|
||||
with cls._style_cache_lock:
|
||||
# clear expired styles
|
||||
now = time()
|
||||
for key in list(cls._style_cache.keys()):
|
||||
if cls._style_cache[key]["expire"] < now:
|
||||
del cls._style_cache[key]
|
||||
|
||||
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
|
||||
if key in cls._style_cache:
|
||||
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": credentials["aippt_access_key"],
|
||||
"x-token": cls._get_api_token(credentials=credentials, user_id=user_id),
|
||||
}
|
||||
response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
colors = [
|
||||
{
|
||||
"id": f'id-{item.get("id")}',
|
||||
"name": item.get("name"),
|
||||
"en_name": item.get("en_name", item.get("name")),
|
||||
}
|
||||
for item in response.get("data", {}).get("colour") or []
|
||||
]
|
||||
styles = [
|
||||
{
|
||||
"id": f'id-{item.get("id")}',
|
||||
"name": item.get("title"),
|
||||
}
|
||||
for item in response.get("data", {}).get("suit_style") or []
|
||||
]
|
||||
|
||||
with cls._style_cache_lock:
|
||||
cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60}
|
||||
|
||||
return colors, styles
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
|
||||
raise Exception("Please provide aippt credentials")
|
||||
|
||||
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
|
||||
|
||||
def _get_suit(self, style_id: int, colour_id: int) -> int:
|
||||
"""
|
||||
Get suit
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
|
||||
}
|
||||
response = get(
|
||||
str(self._api_base_url / "template_component" / "suit" / "search"),
|
||||
headers=headers,
|
||||
params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
if len(response.get("data", {}).get("list") or []) > 0:
|
||||
return response.get("data", {}).get("list")[0].get("id")
|
||||
|
||||
raise Exception("Failed to get suit, the suit does not exist, please check the style and color")
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
Get runtime parameters
|
||||
|
||||
Override this method to add runtime parameters to the tool.
|
||||
"""
|
||||
try:
|
||||
colors, styles = self.get_styles(user_id="__dify_system__")
|
||||
except Exception as e:
|
||||
colors, styles = (
|
||||
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
|
||||
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
|
||||
)
|
||||
|
||||
return [
|
||||
ToolParameter(
|
||||
name="color",
|
||||
label=I18nObject(zh_Hans="颜色", en_US="Color"),
|
||||
human_description=I18nObject(zh_Hans="颜色", en_US="Color"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=False,
|
||||
default=colors[0]["id"],
|
||||
options=[
|
||||
ToolParameterOption(
|
||||
value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"])
|
||||
)
|
||||
for color in colors
|
||||
],
|
||||
),
|
||||
ToolParameter(
|
||||
name="style",
|
||||
label=I18nObject(zh_Hans="风格", en_US="Style"),
|
||||
human_description=I18nObject(zh_Hans="风格", en_US="Style"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=False,
|
||||
default=styles[0]["id"],
|
||||
options=[
|
||||
ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"]))
|
||||
for style in styles
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
identity:
|
||||
name: aippt
|
||||
author: Dify
|
||||
label:
|
||||
en_US: AIPPT
|
||||
zh_Hans: AIPPT
|
||||
description:
|
||||
human:
|
||||
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
|
||||
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
|
||||
llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
|
||||
parameters:
|
||||
- name: title
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Title
|
||||
zh_Hans: 标题
|
||||
human_description:
|
||||
en_US: The title of the PPT.
|
||||
zh_Hans: PPT的标题。
|
||||
llm_description: The title of the PPT, which will be used to generate the PPT outline.
|
||||
form: llm
|
||||
- name: outline
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Outline
|
||||
zh_Hans: 大纲
|
||||
human_description:
|
||||
en_US: The outline of the PPT
|
||||
zh_Hans: PPT的大纲
|
||||
llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have.
|
||||
form: llm
|
||||
- name: llm
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: LLM model
|
||||
zh_Hans: 生成大纲的LLM
|
||||
options:
|
||||
- value: aippt
|
||||
label:
|
||||
en_US: AIPPT default model
|
||||
zh_Hans: AIPPT默认模型
|
||||
- value: wenxin
|
||||
label:
|
||||
en_US: Wenxin ErnieBot
|
||||
zh_Hans: 文心一言
|
||||
default: aippt
|
||||
human_description:
|
||||
en_US: The LLM model used for generating PPT outline.
|
||||
zh_Hans: 用于生成PPT大纲的LLM模型。
|
||||
form: form
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="56px" height="56px" viewBox="0 0 56 56" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>形状结合</title>
|
||||
<g id="设计规范" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<path d="M56,0 L56,56 L0,56 L0,0 L56,0 Z M31.6063018,12 L24.3936982,12 L24.1061064,12.7425499 L12.6071308,42.4324141 L12,44 L19.7849972,44 L20.0648488,43.2391815 L22.5196173,36.5567427 L33.4780427,36.5567427 L35.9351512,43.2391815 L36.2150028,44 L44,44 L43.3928692,42.4324141 L31.8938936,12.7425499 L31.6063018,12 Z M28.0163803,21.5755126 L31.1613993,30.2523823 L24.8432808,30.2523823 L28.0163803,21.5755126 Z" id="形状结合" fill="#2F4F4F"></path>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 780 B |
|
|
@ -1,22 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AlphaVantageProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
QueryStockTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"code": "AAPL", # Apple Inc.
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
identity:
|
||||
author: zhuhao
|
||||
name: alphavantage
|
||||
label:
|
||||
en_US: AlphaVantage
|
||||
zh_Hans: AlphaVantage
|
||||
pt_BR: AlphaVantage
|
||||
description:
|
||||
en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
|
||||
zh_Hans: AlphaVantage是一个在线平台,它提供金融市场数据和API,便于个人投资者和开发者获取股票报价、技术指标和股票分析。
|
||||
pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- finance
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AlphaVantage API key
|
||||
zh_Hans: AlphaVantage API key
|
||||
pt_BR: AlphaVantage API key
|
||||
placeholder:
|
||||
en_US: Please input your AlphaVantage API key
|
||||
zh_Hans: 请输入你的 AlphaVantage API key
|
||||
pt_BR: Please input your AlphaVantage API key
|
||||
help:
|
||||
en_US: Get your AlphaVantage API key from AlphaVantage
|
||||
zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key
|
||||
pt_BR: Get your AlphaVantage API key from AlphaVantage
|
||||
url: https://www.alphavantage.co/support/#api-key
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query"
|
||||
|
||||
|
||||
class QueryStockTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
stock_code = tool_parameters.get("code", "")
|
||||
if not stock_code:
|
||||
return self.create_text_message("Please tell me your stock code")
|
||||
|
||||
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
|
||||
return self.create_text_message("Alpha Vantage API key is required.")
|
||||
|
||||
params = {
|
||||
"function": "TIME_SERIES_DAILY",
|
||||
"symbol": stock_code,
|
||||
"outputsize": "compact",
|
||||
"datatype": "json",
|
||||
"apikey": self.runtime.credentials["api_key"],
|
||||
}
|
||||
response = requests.get(url=ALPHAVANTAGE_API_URL, params=params)
|
||||
response.raise_for_status()
|
||||
result = self._handle_response(response.json())
|
||||
return self.create_json_message(result)
|
||||
|
||||
def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
result = response.get("Time Series (Daily)", {})
|
||||
if not result:
|
||||
return {}
|
||||
stock_result = {}
|
||||
for k, v in result.items():
|
||||
stock_result[k] = {}
|
||||
stock_result[k]["open"] = v.get("1. open")
|
||||
stock_result[k]["high"] = v.get("2. high")
|
||||
stock_result[k]["low"] = v.get("3. low")
|
||||
stock_result[k]["close"] = v.get("4. close")
|
||||
stock_result[k]["volume"] = v.get("5. volume")
|
||||
return stock_result
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
identity:
|
||||
name: query_stock
|
||||
author: zhuhao
|
||||
label:
|
||||
en_US: query_stock
|
||||
zh_Hans: query_stock
|
||||
pt_BR: query_stock
|
||||
description:
|
||||
human:
|
||||
en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol.
|
||||
zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。
|
||||
pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
|
||||
llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
|
||||
parameters:
|
||||
- name: code
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: stock code
|
||||
zh_Hans: 股票代码
|
||||
pt_BR: stock code
|
||||
human_description:
|
||||
en_US: stock code
|
||||
zh_Hans: 股票代码
|
||||
pt_BR: stock code
|
||||
llm_description: stock code for query from alphavantage
|
||||
form: llm
|
||||
|
|
@ -1 +0,0 @@
|
|||
<svg id="logomark" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 17.732 24.269"><g id="tiny"><path d="M573.549,280.916l2.266,2.738,6.674-7.84c.353-.47.52-.717.353-1.117a1.218,1.218,0,0,0-1.061-.748h0a.953.953,0,0,0-.712.262Z" transform="translate(-566.984 -271.548)" fill="#bdb9b4"/><path d="M579.525,282.225l-10.606-10.174a1.413,1.413,0,0,0-.834-.5,1.09,1.09,0,0,0-1.027.66c-.167.4-.047.681.319,1.206l8.44,10.242h0l-6.282,7.716a1.336,1.336,0,0,0-.323,1.3,1.114,1.114,0,0,0,1.04.69A.992.992,0,0,0,571,293l8.519-7.92A1.924,1.924,0,0,0,579.525,282.225Z" transform="translate(-566.984 -271.548)" fill="#b31b1b"/><path d="M584.32,293.912l-8.525-10.275,0,0L573.53,280.9l-1.389,1.254a2.063,2.063,0,0,0,0,2.965l10.812,10.419a.925.925,0,0,0,.742.282,1.039,1.039,0,0,0,.953-.667A1.261,1.261,0,0,0,584.32,293.912Z" transform="translate(-566.984 -271.548)" fill="#bdb9b4"/></g></svg>
|
||||
|
Before Width: | Height: | Size: 874 B |
|
|
@ -1,20 +0,0 @@
|
|||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.arxiv.tools.arxiv_search import ArxivSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class ArxivProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
ArxivSearchTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"query": "John Doe",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
identity:
|
||||
author: Yash Parmar
|
||||
name: arxiv
|
||||
label:
|
||||
en_US: ArXiv
|
||||
zh_Hans: ArXiv
|
||||
description:
|
||||
en_US: Access to a vast repository of scientific papers and articles in various fields of research.
|
||||
zh_Hans: 访问各个研究领域大量科学论文和文章的存储库。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import arxiv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArxivAPIWrapper(BaseModel):
|
||||
"""Wrapper around ArxivAPI.
|
||||
|
||||
To use, you should have the ``arxiv`` python package installed.
|
||||
https://lukasschwab.me/arxiv.py/index.html
|
||||
This wrapper will use the Arxiv API to conduct searches and
|
||||
fetch document summaries. By default, it will return the document summaries
|
||||
of the top-k results.
|
||||
It limits the Document content by doc_content_chars_max.
|
||||
Set doc_content_chars_max=None if you don't want to limit the content size.
|
||||
|
||||
Args:
|
||||
top_k_results: number of the top-scored document used for the arxiv tool
|
||||
ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
|
||||
load_max_docs: a limit to the number of loaded documents
|
||||
load_all_available_meta:
|
||||
if True: the `metadata` of the loaded Documents contains all available
|
||||
meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
|
||||
if False: the `metadata` contains only the published date, title,
|
||||
authors and summary.
|
||||
doc_content_chars_max: an optional cut limit for the length of a document's
|
||||
content
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
arxiv = ArxivAPIWrapper(
|
||||
top_k_results = 3,
|
||||
ARXIV_MAX_QUERY_LENGTH = 300,
|
||||
load_max_docs = 3,
|
||||
load_all_available_meta = False,
|
||||
doc_content_chars_max = 40000
|
||||
)
|
||||
arxiv.run("tree of thought llm)
|
||||
"""
|
||||
|
||||
arxiv_search: type[arxiv.Search] = arxiv.Search #: :meta private:
|
||||
arxiv_http_error: tuple[type[Exception]] = (arxiv.ArxivError, arxiv.UnexpectedEmptyPageError, arxiv.HTTPError)
|
||||
top_k_results: int = 3
|
||||
ARXIV_MAX_QUERY_LENGTH: int = 300
|
||||
load_max_docs: int = 100
|
||||
load_all_available_meta: bool = False
|
||||
doc_content_chars_max: Optional[int] = 4000
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""
|
||||
Performs an arxiv search and A single string
|
||||
with the publish date, title, authors, and summary
|
||||
for each article separated by two newlines.
|
||||
|
||||
If an error occurs or no documents found, error text
|
||||
is returned instead. Wrapper for
|
||||
https://lukasschwab.me/arxiv.py/index.html#Search
|
||||
|
||||
Args:
|
||||
query: a plaintext search query
|
||||
"""
|
||||
try:
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
||||
).results()
|
||||
except arxiv_http_error as ex:
|
||||
return f"Arxiv exception: {ex}"
|
||||
docs = [
|
||||
f"Published: {result.updated.date()}\n"
|
||||
f"Title: {result.title}\n"
|
||||
f"Authors: {', '.join(a.name for a in result.authors)}\n"
|
||||
f"Summary: {result.summary}"
|
||||
for result in results
|
||||
]
|
||||
if docs:
|
||||
return "\n\n".join(docs)[: self.doc_content_chars_max]
|
||||
else:
|
||||
return "No good Arxiv Result was found"
|
||||
|
||||
|
||||
class ArxivSearchInput(BaseModel):
|
||||
query: str = Field(..., description="Search query.")
|
||||
|
||||
|
||||
class ArxivSearchTool(BuiltinTool):
|
||||
"""
|
||||
A tool for searching articles on Arxiv.
|
||||
"""
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes the Arxiv search tool with the given user ID and tool parameters.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool, including the 'query' parameter.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
query = tool_parameters.get("query", "")
|
||||
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
arxiv = ArxivAPIWrapper()
|
||||
|
||||
response = arxiv.run(query)
|
||||
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=response))
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
identity:
|
||||
name: arxiv_search
|
||||
author: Yash Parmar
|
||||
label:
|
||||
en_US: Arxiv Search
|
||||
zh_Hans: Arxiv 搜索
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name.
|
||||
zh_Hans: 一个用于从Arxiv存储库搜索科学论文和文章的工具。 输入可以是Arxiv ID或作者姓名。
|
||||
llm: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询字符串
|
||||
human_description:
|
||||
en_US: The Arxiv ID or author's name used for searching.
|
||||
zh_Hans: 用于搜索的Arxiv ID或作者姓名。
|
||||
llm_description: The Arxiv ID or author's name used for searching.
|
||||
form: llm
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg" fill="none">
|
||||
<path fill="#252F3E" d="M4.51 7.687c0 .197.02.357.058.475.042.117.096.245.17.384a.233.233 0 01.037.123c0 .053-.032.107-.1.16l-.336.224a.255.255 0 01-.138.048c-.054 0-.107-.026-.16-.074a1.652 1.652 0 01-.192-.251 4.137 4.137 0 01-.165-.315c-.415.491-.936.737-1.564.737-.447 0-.804-.129-1.064-.385-.261-.256-.394-.598-.394-1.025 0-.454.16-.822.484-1.1.325-.278.756-.416 1.304-.416.18 0 .367.016.564.042.197.027.4.07.612.118v-.39c0-.406-.085-.689-.25-.854-.17-.166-.458-.246-.868-.246-.186 0-.377.022-.574.07a4.23 4.23 0 00-.575.181 1.525 1.525 0 01-.186.07.326.326 0 01-.085.016c-.075 0-.112-.054-.112-.166v-.262c0-.085.01-.15.037-.186a.399.399 0 01.15-.113c.185-.096.409-.176.67-.24.26-.07.537-.101.83-.101.633 0 1.096.144 1.394.432.293.288.442.726.442 1.314v1.73h.01zm-2.161.811c.175 0 .356-.032.548-.096.191-.064.362-.182.505-.342a.848.848 0 00.181-.341c.032-.129.054-.283.054-.465V7.03a4.43 4.43 0 00-.49-.09 3.996 3.996 0 00-.5-.033c-.357 0-.618.07-.793.214-.176.144-.26.347-.26.614 0 .25.063.437.196.566.128.133.314.197.559.197zm4.273.577c-.096 0-.16-.016-.202-.054-.043-.032-.08-.106-.112-.208l-1.25-4.127a.938.938 0 01-.049-.214c0-.085.043-.133.128-.133h.522c.1 0 .17.016.207.053.043.032.075.107.107.208l.894 3.535.83-3.535c.026-.106.058-.176.1-.208a.365.365 0 01.214-.053h.425c.102 0 .17.016.213.053.043.032.08.107.101.208l.841 3.578.92-3.578a.458.458 0 01.107-.208.346.346 0 01.208-.053h.495c.085 0 .133.043.133.133 0 .027-.006.054-.01.086a.76.76 0 01-.038.133l-1.283 4.127c-.032.107-.069.177-.111.209a.34.34 0 01-.203.053h-.457c-.101 0-.17-.016-.213-.053-.043-.038-.08-.107-.101-.214L8.213 5.37l-.82 3.439c-.026.107-.058.176-.1.213-.043.038-.118.054-.213.054h-.458zm6.838.144a3.51 3.51 0 01-.82-.096c-.266-.064-.473-.134-.612-.214-.085-.048-.143-.101-.165-.15a.378.378 0 01-.031-.149v-.272c0-.112.042-.166.122-.166a.3.3 0 01.096.016c.032.011.08.032.133.054.18.08.378.144.585.187.213.042.42.064.633.064.336 0 .596-.059.777-.176a.575.575 0 00.277-.508.52.52 0 00-.144-.373c-.095-.102-.276-.193-.537-.278l-.772-.24c-.388-.123-.676-.305-.851-.545a1.275 1.275 0 01-.266-.774c0-.224.048-.422.143-.593.096-.17.224-.32.384-.438.16-.122.34-.213.553-.277.213-.064.436-.091.67-.091.118 0 .24.005.357.021.122.016.234.038.346.06.106.026.208.052.303.085.096.032.17.064.224.096a.46.46 0 01.16.133.289.289 0 01.047.176v.251c0 .112-.042.171-.122.171a.552.552 0 01-.202-.064 2.427 2.427 0 00-1.022-.208c-.303 0-.543.048-.708.15-.165.1-.25.256-.25.475 0 .149.053.277.16.379.106.101.303.202.585.293l.756.24c.383.123.66.294.825.513.165.219.244.47.244.748 0 .23-.047.437-.138.619a1.436 1.436 0 01-.388.47c-.165.133-.362.23-.591.299-.24.075-.49.112-.761.112z"/>
|
||||
<g fill="#F90" fill-rule="evenodd" clip-rule="evenodd">
|
||||
<path d="M14.465 11.813c-1.75 1.297-4.294 1.986-6.481 1.986-3.065 0-5.827-1.137-7.913-3.027-.165-.15-.016-.353.18-.235 2.257 1.313 5.04 2.109 7.92 2.109 1.941 0 4.075-.406 6.039-1.239.293-.133.543.192.255.406z"/>
|
||||
<path d="M15.194 10.98c-.223-.287-1.479-.138-2.048-.069-.17.022-.197-.128-.043-.24 1-.705 2.645-.502 2.836-.267.192.24-.053 1.89-.99 2.68-.143.123-.281.06-.218-.1.213-.53.687-1.72.463-2.003z"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 3.3 KiB |
|
|
@ -1,24 +0,0 @@
|
|||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aws.tools.sagemaker_text_rerank import SageMakerReRankTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class SageMakerProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
SageMakerReRankTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"sagemaker_endpoint": "",
|
||||
"query": "misaka mikoto",
|
||||
"candidate_texts": "hello$$$hello world",
|
||||
"topk": 5,
|
||||
"aws_region": "",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
identity:
|
||||
author: AWS
|
||||
name: aws
|
||||
label:
|
||||
en_US: AWS
|
||||
zh_Hans: 亚马逊云科技
|
||||
pt_BR: AWS
|
||||
description:
|
||||
en_US: Services on AWS.
|
||||
zh_Hans: 亚马逊云科技的各类服务
|
||||
pt_BR: Services on AWS.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuardrailParameters(BaseModel):
|
||||
guardrail_id: str = Field(..., description="The identifier of the guardrail")
|
||||
guardrail_version: str = Field(..., description="The version of the guardrail")
|
||||
source: str = Field(..., description="The source of the content")
|
||||
text: str = Field(..., description="The text to apply the guardrail to")
|
||||
aws_region: str = Field(..., description="AWS region for the Bedrock client")
|
||||
|
||||
|
||||
class ApplyGuardrailTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke the ApplyGuardrail tool
|
||||
"""
|
||||
try:
|
||||
# Validate and parse input parameters
|
||||
params = GuardrailParameters(**tool_parameters)
|
||||
|
||||
# Initialize AWS client
|
||||
bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region)
|
||||
|
||||
# Apply guardrail
|
||||
response = bedrock_client.apply_guardrail(
|
||||
guardrailIdentifier=params.guardrail_id,
|
||||
guardrailVersion=params.guardrail_version,
|
||||
source=params.source,
|
||||
content=[{"text": {"text": params.text}}],
|
||||
)
|
||||
|
||||
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
|
||||
|
||||
# Check for empty response
|
||||
if not response:
|
||||
return self.create_text_message(text="Received empty response from AWS Bedrock.")
|
||||
|
||||
# Process the result
|
||||
action = response.get("action", "No action specified")
|
||||
outputs = response.get("outputs", [])
|
||||
output = outputs[0].get("text", "No output received") if outputs else "No output received"
|
||||
assessments = response.get("assessments", [])
|
||||
|
||||
# Format assessments
|
||||
formatted_assessments = []
|
||||
for assessment in assessments:
|
||||
for policy_type, policy_data in assessment.items():
|
||||
if isinstance(policy_data, dict) and "topics" in policy_data:
|
||||
for topic in policy_data["topics"]:
|
||||
formatted_assessments.append(
|
||||
f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']},"
|
||||
f" Action: {topic['action']}"
|
||||
)
|
||||
else:
|
||||
formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}")
|
||||
|
||||
result = f"Action: {action}\n "
|
||||
result += f"Output: {output}\n "
|
||||
if formatted_assessments:
|
||||
result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
|
||||
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except BotoCoreError as e:
|
||||
error_message = f"AWS service error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except json.JSONDecodeError as e:
|
||||
error_message = f"JSON parsing error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except Exception as e:
|
||||
error_message = f"An unexpected error occurred: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
identity:
|
||||
name: apply_guardrail
|
||||
author: AWS
|
||||
label:
|
||||
en_US: Content Moderation Guardrails
|
||||
zh_Hans: 内容审查护栏
|
||||
description:
|
||||
human:
|
||||
en_US: Content Moderation Guardrails utilizes the ApplyGuardrail API, a feature of Guardrails for Amazon Bedrock. This API is capable of evaluating input prompts and model responses for all Foundation Models (FMs), including those on Amazon Bedrock, custom FMs, and third-party FMs. By implementing this functionality, organizations can achieve centralized governance across all their generative AI applications, thereby enhancing control and consistency in content moderation.
|
||||
zh_Hans: 内容审查护栏采用 Guardrails for Amazon Bedrock 功能中的 ApplyGuardrail API 。ApplyGuardrail 可以评估所有基础模型(FMs)的输入提示和模型响应,包括 Amazon Bedrock 上的 FMs、自定义 FMs 和第三方 FMs。通过实施这一功能, 组织可以在所有生成式 AI 应用程序中实现集中化的治理,从而增强内容审核的控制力和一致性。
|
||||
llm: Content Moderation Guardrails utilizes the ApplyGuardrail API, a feature of Guardrails for Amazon Bedrock. This API is capable of evaluating input prompts and model responses for all Foundation Models (FMs), including those on Amazon Bedrock, custom FMs, and third-party FMs. By implementing this functionality, organizations can achieve centralized governance across all their generative AI applications, thereby enhancing control and consistency in content moderation.
|
||||
parameters:
|
||||
- name: guardrail_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Guardrail ID
|
||||
zh_Hans: Guardrail ID
|
||||
human_description:
|
||||
en_US: Please enter the ID of the Guardrail that has already been created on Amazon Bedrock, for example 'qk5nk0e4b77b'.
|
||||
zh_Hans: 请输入已经在 Amazon Bedrock 上创建好的 Guardrail ID, 例如 'qk5nk0e4b77b'.
|
||||
llm_description: Please enter the ID of the Guardrail that has already been created on Amazon Bedrock, for example 'qk5nk0e4b77b'.
|
||||
form: form
|
||||
- name: guardrail_version
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Guardrail Version Number
|
||||
zh_Hans: Guardrail 版本号码
|
||||
human_description:
|
||||
en_US: Please enter the published version of the Guardrail ID that has already been created on Amazon Bedrock. This is typically a version number, such as 2.
|
||||
zh_Hans: 请输入已经在Amazon Bedrock 上创建好的Guardrail ID发布的版本, 通常使用版本号, 例如2.
|
||||
llm_description: Please enter the published version of the Guardrail ID that has already been created on Amazon Bedrock. This is typically a version number, such as 2.
|
||||
form: form
|
||||
- name: source
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Content Source (INPUT or OUTPUT)
|
||||
zh_Hans: 内容来源 (INPUT or OUTPUT)
|
||||
human_description:
|
||||
en_US: The source of data used in the request to apply the guardrail. Valid Values "INPUT | OUTPUT"
|
||||
zh_Hans: 用于应用护栏的请求中所使用的数据来源。有效值为 "INPUT | OUTPUT"
|
||||
llm_description: The source of data used in the request to apply the guardrail. Valid Values "INPUT | OUTPUT"
|
||||
form: form
|
||||
- name: text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Content to be reviewed
|
||||
zh_Hans: 待审查内容
|
||||
human_description:
|
||||
en_US: The content used for requesting guardrail review, which can be either user input or LLM output.
|
||||
zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。
|
||||
llm_description: The content used for requesting guardrail review, which can be either user input or LLM output.
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。
|
||||
llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
form: form
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class LambdaTranslateUtilsTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name):
|
||||
msg = {
|
||||
"src_content": text_content,
|
||||
"src_lang": src_lang,
|
||||
"dest_lang": dest_lang,
|
||||
"dictionary_id": dictionary_name,
|
||||
"request_type": request_type,
|
||||
"model_id": model_id,
|
||||
}
|
||||
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("unicode_escape")
|
||||
|
||||
return response_str
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
line = 1
|
||||
text_content = tool_parameters.get("text_content", "")
|
||||
if not text_content:
|
||||
return self.create_text_message("Please input text_content")
|
||||
|
||||
line = 2
|
||||
src_lang = tool_parameters.get("src_lang", "")
|
||||
if not src_lang:
|
||||
return self.create_text_message("Please input src_lang")
|
||||
|
||||
line = 3
|
||||
dest_lang = tool_parameters.get("dest_lang", "")
|
||||
if not dest_lang:
|
||||
return self.create_text_message("Please input dest_lang")
|
||||
|
||||
line = 4
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
|
||||
line = 5
|
||||
request_type = tool_parameters.get("request_type", "")
|
||||
if not request_type:
|
||||
return self.create_text_message("Please input request_type")
|
||||
|
||||
line = 6
|
||||
model_id = tool_parameters.get("model_id", "")
|
||||
if not model_id:
|
||||
return self.create_text_message("Please input model_id")
|
||||
|
||||
line = 7
|
||||
dictionary_name = tool_parameters.get("dictionary_name", "")
|
||||
if not dictionary_name:
|
||||
return self.create_text_message("Please input dictionary_name")
|
||||
|
||||
result = self._invoke_lambda(
|
||||
text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name
|
||||
)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
identity:
|
||||
name: lambda_translate_utils
|
||||
author: AWS
|
||||
label:
|
||||
en_US: TranslateTool
|
||||
zh_Hans: 翻译工具
|
||||
pt_BR: TranslateTool
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
llm: A util tools for translation.
|
||||
parameters:
|
||||
- name: text_content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: source content for translation
|
||||
zh_Hans: 待翻译原文
|
||||
pt_BR: source content for translation
|
||||
human_description:
|
||||
en_US: source content for translation
|
||||
zh_Hans: 待翻译原文
|
||||
pt_BR: source content for translation
|
||||
llm_description: source content for translation
|
||||
form: llm
|
||||
- name: src_lang
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: source language code
|
||||
zh_Hans: 原文语言代号
|
||||
pt_BR: source language code
|
||||
human_description:
|
||||
en_US: source language code
|
||||
zh_Hans: 原文语言代号
|
||||
pt_BR: source language code
|
||||
llm_description: source language code
|
||||
form: llm
|
||||
- name: dest_lang
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: target language code
|
||||
zh_Hans: 目标语言代号
|
||||
pt_BR: target language code
|
||||
human_description:
|
||||
en_US: target language code
|
||||
zh_Hans: 目标语言代号
|
||||
pt_BR: target language code
|
||||
llm_description: target language code
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of Lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of Lambda
|
||||
human_description:
|
||||
en_US: region of Lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of Lambda
|
||||
llm_description: region of Lambda
|
||||
form: form
|
||||
- name: model_id
|
||||
type: string
|
||||
required: false
|
||||
default: anthropic.claude-3-sonnet-20240229-v1:0
|
||||
label:
|
||||
en_US: LLM model_id in bedrock
|
||||
zh_Hans: bedrock上的大语言模型model_id
|
||||
pt_BR: LLM model_id in bedrock
|
||||
human_description:
|
||||
en_US: LLM model_id in bedrock
|
||||
zh_Hans: bedrock上的大语言模型model_id
|
||||
pt_BR: LLM model_id in bedrock
|
||||
llm_description: LLM model_id in bedrock
|
||||
form: form
|
||||
- name: dictionary_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: dictionary name for term mapping
|
||||
zh_Hans: 专词映射表名称
|
||||
pt_BR: dictionary name for term mapping
|
||||
human_description:
|
||||
en_US: dictionary name for term mapping
|
||||
zh_Hans: 专词映射表名称
|
||||
pt_BR: dictionary name for term mapping
|
||||
llm_description: dictionary name for term mapping
|
||||
form: form
|
||||
- name: request_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: request type
|
||||
zh_Hans: 请求类型
|
||||
pt_BR: request type
|
||||
human_description:
|
||||
en_US: request type
|
||||
zh_Hans: 请求类型
|
||||
pt_BR: request type
|
||||
default: term_mapping
|
||||
options:
|
||||
- value: term_mapping
|
||||
label:
|
||||
en_US: term_mapping
|
||||
zh_Hans: 专词映射
|
||||
- value: segment_only
|
||||
label:
|
||||
en_US: segment_only
|
||||
zh_Hans: 仅切词
|
||||
- value: translate
|
||||
label:
|
||||
en_US: translate
|
||||
zh_Hans: 翻译内容
|
||||
form: form
|
||||
- name: lambda_name
|
||||
type: string
|
||||
default: "translate_tool"
|
||||
required: true
|
||||
label:
|
||||
en_US: AWS Lambda for term mapping retrieval
|
||||
zh_Hans: 专词召回映射 - AWS Lambda
|
||||
pt_BR: lambda name for term mapping retrieval
|
||||
human_description:
|
||||
en_US: AWS Lambda for term mapping retrieval
|
||||
zh_Hans: 专词召回映射 - AWS Lambda
|
||||
pt_BR: AWS Lambda for term mapping retrieval
|
||||
llm_description: AWS Lambda for term mapping retrieval
|
||||
form: form
|
||||
|
|
@ -1,70 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
class LambdaYamlToJsonTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
|
||||
msg = {"body": yaml_content}
|
||||
logger.info(json.dumps(msg))
|
||||
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("utf-8")
|
||||
resp_json = json.loads(response_str)
|
||||
|
||||
logger.info(resp_json)
|
||||
if resp_json["statusCode"] != 200:
|
||||
raise Exception(f"Invalid status code: {response_str}")
|
||||
|
||||
return resp_json["body"]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
yaml_content = tool_parameters.get("yaml_content", "")
|
||||
if not yaml_content:
|
||||
return self.create_text_message("Please input yaml_content")
|
||||
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}")
|
||||
|
||||
result = self._invoke_lambda(lambda_name, yaml_content)
|
||||
logger.debug(result)
|
||||
|
||||
return self.create_text_message(result)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception: {str(e)}")
|
||||
|
||||
console_handler.flush()
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
identity:
|
||||
name: lambda_yaml_to_json
|
||||
author: AWS
|
||||
label:
|
||||
en_US: LambdaYamlToJson
|
||||
zh_Hans: LambdaYamlToJson
|
||||
pt_BR: LambdaYamlToJson
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool to convert yaml to json using AWS Lambda.
|
||||
zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。
|
||||
pt_BR: A tool to convert yaml to json using AWS Lambda.
|
||||
llm: A tool to convert yaml to json.
|
||||
parameters:
|
||||
- name: yaml_content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: YAML content to convert for
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert for
|
||||
human_description:
|
||||
en_US: YAML content to convert for
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert for
|
||||
llm_description: YAML content to convert for
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
human_description:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
llm_description: region of lambda
|
||||
form: form
|
||||
- name: lambda_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: name of lambda
|
||||
zh_Hans: Lambda 名称
|
||||
pt_BR: name of lambda
|
||||
human_description:
|
||||
en_US: name of lambda
|
||||
zh_Hans: Lambda 名称
|
||||
pt_BR: name of lambda
|
||||
form: form
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
import json
|
||||
import operator
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SageMakerReRankTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
topk: int = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
||||
inputs = [query_input] * len(docs)
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=rerank_endpoint,
|
||||
Body=json.dumps({"inputs": inputs, "docs": docs}),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
scores = json_obj["scores"]
|
||||
return scores if isinstance(scores, list) else [scores]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.sagemaker_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
line = 1
|
||||
if not self.sagemaker_endpoint:
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
line = 2
|
||||
if not self.topk:
|
||||
self.topk = tool_parameters.get("topk", 5)
|
||||
|
||||
line = 3
|
||||
query = tool_parameters.get("query", "")
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
line = 4
|
||||
candidate_texts = tool_parameters.get("candidate_texts")
|
||||
if not candidate_texts:
|
||||
return self.create_text_message("Please input candidate_texts")
|
||||
|
||||
line = 5
|
||||
candidate_docs = json.loads(candidate_texts)
|
||||
docs = [item.get("content") for item in candidate_docs]
|
||||
|
||||
line = 6
|
||||
scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint)
|
||||
|
||||
line = 7
|
||||
for idx in range(len(candidate_docs)):
|
||||
candidate_docs[idx]["score"] = scores[idx]
|
||||
|
||||
line = 8
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 9
|
||||
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
identity:
|
||||
name: sagemaker_text_rerank
|
||||
author: AWS
|
||||
label:
|
||||
en_US: SagemakerRerank
|
||||
zh_Hans: Sagemaker重排序
|
||||
pt_BR: SagemakerRerank
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for performing text similarity ranking. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: Sagemaker重排序工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
|
||||
pt_BR: A tool for performing text similarity ranking.
|
||||
llm: A tool for performing text similarity ranking. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
parameters:
|
||||
- name: sagemaker_endpoint
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: sagemaker endpoint for reranking
|
||||
zh_Hans: 重排序的SageMaker 端点
|
||||
pt_BR: sagemaker endpoint for reranking
|
||||
human_description:
|
||||
en_US: sagemaker endpoint for reranking
|
||||
zh_Hans: 重排序的SageMaker 端点
|
||||
pt_BR: sagemaker endpoint for reranking
|
||||
llm_description: sagemaker endpoint for reranking
|
||||
form: form
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: Query string
|
||||
human_description:
|
||||
en_US: key words for searching
|
||||
zh_Hans: 查询关键词
|
||||
pt_BR: key words for searching
|
||||
llm_description: key words for searching
|
||||
form: llm
|
||||
- name: candidate_texts
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: text candidates
|
||||
zh_Hans: 候选文本
|
||||
pt_BR: text candidates
|
||||
human_description:
|
||||
en_US: searched candidates by query
|
||||
zh_Hans: 查询文本搜到候选文本
|
||||
pt_BR: searched candidates by query
|
||||
llm_description: searched candidates by query
|
||||
form: llm
|
||||
- name: topk
|
||||
type: number
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Limit for results count
|
||||
zh_Hans: 返回个数限制
|
||||
pt_BR: Limit for results count
|
||||
human_description:
|
||||
en_US: Limit for results count
|
||||
zh_Hans: 返回个数限制
|
||||
pt_BR: Limit for results count
|
||||
min: 1
|
||||
max: 10
|
||||
default: 5
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
human_description:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
llm_description: region of sagemaker endpoint
|
||||
form: form
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class TTSModelType(Enum):
|
||||
PresetVoice = "PresetVoice"
|
||||
CloneVoice = "CloneVoice"
|
||||
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
|
||||
InstructVoice = "InstructVoice"
|
||||
|
||||
|
||||
class SageMakerTTSTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
s3_client: Any = None
|
||||
comprehend_client: Any = None
|
||||
|
||||
def _detect_lang_code(self, content: str, map_dict: dict = None):
|
||||
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
|
||||
|
||||
response = self.comprehend_client.detect_dominant_language(Text=content)
|
||||
language_code = response["Languages"][0]["LanguageCode"]
|
||||
return map_dict.get(language_code, "<|zh|>")
|
||||
|
||||
def _build_tts_payload(
|
||||
self,
|
||||
model_type: str,
|
||||
content_text: str,
|
||||
model_role: str,
|
||||
prompt_text: str,
|
||||
prompt_audio: str,
|
||||
instruct_text: str,
|
||||
):
|
||||
if model_type == TTSModelType.PresetVoice.value and model_role:
|
||||
return {"tts_text": content_text, "role": model_role}
|
||||
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
|
||||
return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
|
||||
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
|
||||
lang_tag = self._detect_lang_code(content_text)
|
||||
return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
|
||||
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
|
||||
return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}
|
||||
|
||||
raise RuntimeError(f"Invalid params for {model_type}")
|
||||
|
||||
def _invoke_sagemaker(self, payload: dict, endpoint: str):
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=endpoint,
|
||||
Body=json.dumps(payload),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
return json_obj
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.sagemaker_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||
self.comprehend_client = boto3.client("comprehend", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
self.s3_client = boto3.client("s3")
|
||||
self.comprehend_client = boto3.client("comprehend")
|
||||
|
||||
if not self.sagemaker_endpoint:
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
tts_text = tool_parameters.get("tts_text")
|
||||
tts_infer_type = tool_parameters.get("tts_infer_type")
|
||||
|
||||
voice = tool_parameters.get("voice")
|
||||
mock_voice_audio = tool_parameters.get("mock_voice_audio")
|
||||
mock_voice_text = tool_parameters.get("mock_voice_text")
|
||||
voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt")
|
||||
payload = self._build_tts_payload(
|
||||
tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt
|
||||
)
|
||||
|
||||
result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
|
||||
|
||||
return self.create_text_message(text=result["s3_presign_url"])
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}")
|
||||
|
|
@ -1,149 +0,0 @@
|
|||
identity:
|
||||
name: sagemaker_tts
|
||||
author: AWS
|
||||
label:
|
||||
en_US: SagemakerTTS
|
||||
zh_Hans: Sagemaker语音合成
|
||||
pt_BR: SagemakerTTS
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
|
||||
pt_BR: A tool for Speech synthesis.
|
||||
llm: A tool for Speech synthesis. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
parameters:
|
||||
- name: sagemaker_endpoint
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: sagemaker endpoint for tts
|
||||
zh_Hans: 语音生成的SageMaker端点
|
||||
pt_BR: sagemaker endpoint for tts
|
||||
human_description:
|
||||
en_US: sagemaker endpoint for tts
|
||||
zh_Hans: 语音生成的SageMaker端点
|
||||
pt_BR: sagemaker endpoint for tts
|
||||
llm_description: sagemaker endpoint for tts
|
||||
form: form
|
||||
- name: tts_text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: tts text
|
||||
zh_Hans: 语音合成原文
|
||||
pt_BR: tts text
|
||||
human_description:
|
||||
en_US: tts text
|
||||
zh_Hans: 语音合成原文
|
||||
pt_BR: tts text
|
||||
llm_description: tts text
|
||||
form: llm
|
||||
- name: tts_infer_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: tts infer type
|
||||
zh_Hans: 合成方式
|
||||
pt_BR: tts infer type
|
||||
human_description:
|
||||
en_US: tts infer type
|
||||
zh_Hans: 合成方式
|
||||
pt_BR: tts infer type
|
||||
llm_description: tts infer type
|
||||
options:
|
||||
- value: PresetVoice
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
- value: CloneVoice
|
||||
label:
|
||||
en_US: clone voice
|
||||
zh_Hans: 克隆音色
|
||||
- value: CloneVoice_CrossLingual
|
||||
label:
|
||||
en_US: clone crossLingual voice
|
||||
zh_Hans: 克隆音色(跨语言)
|
||||
- value: InstructVoice
|
||||
label:
|
||||
en_US: instruct voice
|
||||
zh_Hans: 指令音色
|
||||
form: form
|
||||
- name: voice
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
pt_BR: preset voice
|
||||
human_description:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
pt_BR: preset voice
|
||||
llm_description: preset voice
|
||||
options:
|
||||
- value: 中文男
|
||||
label:
|
||||
en_US: zh-cn male
|
||||
zh_Hans: 中文男
|
||||
- value: 中文女
|
||||
label:
|
||||
en_US: zh-cn female
|
||||
zh_Hans: 中文女
|
||||
- value: 粤语女
|
||||
label:
|
||||
en_US: zh-TW female
|
||||
zh_Hans: 粤语女
|
||||
form: form
|
||||
- name: mock_voice_audio
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: clone voice link
|
||||
zh_Hans: 克隆音频链接
|
||||
pt_BR: clone voice link
|
||||
human_description:
|
||||
en_US: clone voice link
|
||||
zh_Hans: 克隆音频链接
|
||||
pt_BR: clone voice link
|
||||
llm_description: clone voice link
|
||||
form: llm
|
||||
- name: mock_voice_text
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: text of clone voice
|
||||
zh_Hans: 克隆音频对应文本
|
||||
pt_BR: text of clone voice
|
||||
human_description:
|
||||
en_US: text of clone voice
|
||||
zh_Hans: 克隆音频对应文本
|
||||
pt_BR: text of clone voice
|
||||
llm_description: text of clone voice
|
||||
form: llm
|
||||
- name: voice_instruct_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: instruct prompt for voice
|
||||
zh_Hans: 音色指令文本
|
||||
pt_BR: instruct prompt for voice
|
||||
human_description:
|
||||
en_US: instruct prompt for voice
|
||||
zh_Hans: 音色指令文本
|
||||
pt_BR: instruct prompt for voice
|
||||
llm_description: instruct prompt for voice
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
human_description:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
llm_description: region of sagemaker endpoint
|
||||
form: form
|
||||
|
Before Width: | Height: | Size: 50 KiB |
|
|
@ -1,20 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AzureDALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE3Tool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
identity:
|
||||
author: Leslie
|
||||
name: azuredalle
|
||||
label:
|
||||
en_US: Azure DALL-E
|
||||
zh_Hans: Azure DALL-E 绘画
|
||||
pt_BR: Azure DALL-E
|
||||
description:
|
||||
en_US: Azure DALL-E art
|
||||
zh_Hans: Azure DALL-E 绘画
|
||||
pt_BR: Azure DALL-E art
|
||||
icon: icon.png
|
||||
tags:
|
||||
- image
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
azure_openai_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API key
|
||||
zh_Hans: 密钥
|
||||
pt_BR: API key
|
||||
help:
|
||||
en_US: Please input your Azure OpenAI API key
|
||||
zh_Hans: 请输入你的 Azure OpenAI API key
|
||||
pt_BR: Introduza a sua chave de API OpenAI do Azure
|
||||
placeholder:
|
||||
en_US: Please input your Azure OpenAI API key
|
||||
zh_Hans: 请输入你的 Azure OpenAI API key
|
||||
pt_BR: Introduza a sua chave de API OpenAI do Azure
|
||||
azure_openai_api_model_name:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Deployment Name
|
||||
zh_Hans: 部署名称
|
||||
pt_BR: Nome da Implantação
|
||||
help:
|
||||
en_US: Please input the name of your Azure Openai DALL-E API deployment
|
||||
zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称
|
||||
pt_BR: Insira o nome da implantação da API DALL-E do Azure Openai
|
||||
placeholder:
|
||||
en_US: Please input the name of your Azure Openai DALL-E API deployment
|
||||
zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称
|
||||
pt_BR: Insira o nome da implantação da API DALL-E do Azure Openai
|
||||
azure_openai_base_url:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API Endpoint URL
|
||||
zh_Hans: API 域名
|
||||
pt_BR: API Endpoint URL
|
||||
help:
|
||||
en_US: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
|
||||
zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
|
||||
pt_BR: Introduza a URL do Azure OpenAI Endpoint, e.g. https://xxx.openai.azure.com/
|
||||
placeholder:
|
||||
en_US: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
|
||||
zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
|
||||
pt_BR: Introduza a URL do Azure OpenAI Endpoint, e.g. https://xxx.openai.azure.com/
|
||||
azure_openai_api_version:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API Version
|
||||
zh_Hans: API 版本
|
||||
pt_BR: API Version
|
||||
help:
|
||||
en_US: Please input your Azure OpenAI API Version,e.g. 2023-12-01-preview
|
||||
zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
|
||||
pt_BR: Introduza a versão da API OpenAI do Azure,e.g. 2023-12-01-preview
|
||||
placeholder:
|
||||
en_US: Please input your Azure OpenAI API Version,e.g. 2023-12-01-preview
|
||||
zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
|
||||
pt_BR: Introduza a versão da API OpenAI do Azure,e.g. 2023-12-01-preview
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
import random
|
||||
from base64 import b64decode
|
||||
from typing import Any, Union
|
||||
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
client = AzureOpenAI(
|
||||
api_version=self.runtime.credentials["azure_openai_api_version"],
|
||||
azure_endpoint=self.runtime.credentials["azure_openai_base_url"],
|
||||
api_key=self.runtime.credentials["azure_openai_api_key"],
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
extra_body = {"seed": seed_id}
|
||||
|
||||
# call openapi dalle3
|
||||
model = self.runtime.credentials["azure_openai_api_model_name"]
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
n=n,
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
|
|
@ -1,136 +0,0 @@
|
|||
identity:
|
||||
name: azure_dalle3
|
||||
author: Leslie
|
||||
label:
|
||||
en_US: Azure DALL-E 3
|
||||
zh_Hans: Azure DALL-E 3 绘画
|
||||
pt_BR: Azure DALL-E 3
|
||||
description:
|
||||
en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
|
||||
zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源
|
||||
pt_BR: DALL-E 3 é uma poderosa ferramenta de desenho que pode desenhar a imagem que você deseja com base em seu prompt, em comparação com DallE 2, DallE 3 tem uma capacidade de desenho mais forte, mas consumirá mais recursos
|
||||
description:
|
||||
human:
|
||||
en_US: DALL-E is a text to image tool
|
||||
zh_Hans: DALL-E 是一个文本到图像的工具
|
||||
pt_BR: DALL-E é uma ferramenta de texto para imagem
|
||||
llm: DALL-E is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of DallE 3
|
||||
zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档
|
||||
pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3
|
||||
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
|
||||
form: llm
|
||||
- name: seed_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Seed ID
|
||||
zh_Hans: 种子ID
|
||||
pt_BR: ID da semente
|
||||
human_description:
|
||||
en_US: Image generation seed ID to ensure consistency of series generated images
|
||||
zh_Hans: 图像生成种子ID,确保系列生成图像的一致性
|
||||
pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série
|
||||
llm_description: If the user requests image consistency, extract the seed ID from the user's question or context.The seed id consists of an 8-bit string containing uppercase and lowercase letters and numbers
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image size
|
||||
zh_Hans: 选择图像大小
|
||||
pt_BR: seleccionar o tamanho da imagem
|
||||
label:
|
||||
en_US: Image size
|
||||
zh_Hans: 图像大小
|
||||
pt_BR: Tamanho da imagem
|
||||
form: form
|
||||
options:
|
||||
- value: square
|
||||
label:
|
||||
en_US: Squre(1024x1024)
|
||||
zh_Hans: 方(1024x1024)
|
||||
pt_BR: Squire(1024x1024)
|
||||
- value: vertical
|
||||
label:
|
||||
en_US: Vertical(1024x1792)
|
||||
zh_Hans: 竖屏(1024x1792)
|
||||
pt_BR: Vertical(1024x1792)
|
||||
- value: horizontal
|
||||
label:
|
||||
en_US: Horizontal(1792x1024)
|
||||
zh_Hans: 横屏(1792x1024)
|
||||
pt_BR: Horizontal(1792x1024)
|
||||
default: square
|
||||
- name: n
|
||||
type: number
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the number of images
|
||||
zh_Hans: 选择图像数量
|
||||
pt_BR: seleccionar o número de imagens
|
||||
label:
|
||||
en_US: Number of images
|
||||
zh_Hans: 图像数量
|
||||
pt_BR: Número de imagens
|
||||
form: form
|
||||
min: 1
|
||||
max: 1
|
||||
default: 1
|
||||
- name: quality
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image quality
|
||||
zh_Hans: 选择图像质量
|
||||
pt_BR: seleccionar a qualidade da imagem
|
||||
label:
|
||||
en_US: Image quality
|
||||
zh_Hans: 图像质量
|
||||
pt_BR: Qualidade da imagem
|
||||
form: form
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
pt_BR: Normal
|
||||
- value: hd
|
||||
label:
|
||||
en_US: HD
|
||||
zh_Hans: 高清
|
||||
pt_BR: HD
|
||||
default: standard
|
||||
- name: style
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image style
|
||||
zh_Hans: 选择图像风格
|
||||
pt_BR: seleccionar o estilo da imagem
|
||||
label:
|
||||
en_US: Image style
|
||||
zh_Hans: 图像风格
|
||||
pt_BR: Estilo da imagem
|
||||
form: form
|
||||
options:
|
||||
- value: vivid
|
||||
label:
|
||||
en_US: Vivid
|
||||
zh_Hans: 生动
|
||||
pt_BR: Vívido
|
||||
- value: natural
|
||||
label:
|
||||
en_US: Natural
|
||||
zh_Hans: 自然
|
||||
pt_BR: Natural
|
||||
default: vivid
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
<svg viewBox="-29.62167543756803 0.1 574.391675437568 799.8100000000002" xmlns="http://www.w3.org/2000/svg" width="1888"
|
||||
height="2500">
|
||||
<linearGradient id="a" gradientUnits="userSpaceOnUse" x1="286.383" x2="542.057" y1="284.169" y2="569.112">
|
||||
<stop offset="0" stop-color="#37bdff"/>
|
||||
<stop offset=".25" stop-color="#26c6f4"/>
|
||||
<stop offset=".5" stop-color="#15d0e9"/>
|
||||
<stop offset=".75" stop-color="#3bd6df"/>
|
||||
<stop offset="1" stop-color="#62dcd4"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="b" gradientUnits="userSpaceOnUse" x1="108.979" x2="100.756" y1="675.98" y2="43.669">
|
||||
<stop offset="0" stop-color="#1b48ef"/>
|
||||
<stop offset=".5" stop-color="#2080f1"/>
|
||||
<stop offset="1" stop-color="#26b8f4"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="c" gradientUnits="userSpaceOnUse" x1="256.823" x2="875.632" y1="649.719" y2="649.719">
|
||||
<stop offset="0" stop-color="#39d2ff"/>
|
||||
<stop offset=".5" stop-color="#248ffa"/>
|
||||
<stop offset="1" stop-color="#104cf5"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="d" gradientUnits="userSpaceOnUse" x1="256.823" x2="875.632" y1="649.719" y2="649.719">
|
||||
<stop offset="0" stop-color="#fff"/>
|
||||
<stop offset="1"/>
|
||||
</linearGradient>
|
||||
<path d="M249.97 277.48c-.12.96-.12 2.05-.12 3.12 0 4.16.83 8.16 2.33 11.84l1.34 2.76 5.3 13.56 27.53 70.23 24.01 61.33c6.85 12.38 17.82 22.1 31.05 27.28l4.11 1.51c.16.05.43.05.65.11l65.81 22.63v.05l25.16 8.64 1.72.58c.06 0 .16.06.22.06 4.96 1.25 9.82 2.93 14.46 4.98 10.73 4.63 20.46 11.23 28.77 19.28 3.35 3.2 6.43 6.65 9.28 10.33a88.64 88.64 0 0 1 6.64 9.72c8.78 14.58 13.82 31.72 13.82 49.97 0 3.26-.16 6.41-.49 9.61-.11 1.41-.28 2.77-.49 4.12v.11c-.22 1.43-.49 2.91-.76 4.36-.28 1.41-.54 2.81-.86 4.21-.05.16-.11.33-.17.49-.3 1.42-.68 2.82-1.07 4.23-.35 1.33-.79 2.7-1.28 3.99a42.96 42.96 0 0 1-1.51 4.16c-.49 1.4-1.07 2.82-1.72 4.16-1.78 4.11-3.9 8.06-6.28 11.83a97.889 97.889 0 0 1-10.47 13.95c30.88-33.2 51.41-76.07 56.52-123.51.86-7.78 1.3-15.67 1.3-23.61 0-5.07-.22-10.09-.55-15.13-3.89-56.89-29.79-107.77-69.32-144.08-10.9-10.09-22.81-19.07-35.62-26.69l-24.2-12.37-122.63-62.93a30.15 30.15 0 0 0-11.93-2.44c-15.88 0-28.99 12.11-30.55 27.56z"
|
||||
fill="#7f7f7f"/>
|
||||
<path d="M249.97 277.48c-.12.96-.12 2.05-.12 3.12 0 4.16.83 8.16 2.33 11.84l1.34 2.76 5.3 13.56 27.53 70.23 24.01 61.33c6.85 12.38 17.82 22.1 31.05 27.28l4.11 1.51c.16.05.43.05.65.11l65.81 22.63v.05l25.16 8.64 1.72.58c.06 0 .16.06.22.06 4.96 1.25 9.82 2.93 14.46 4.98 10.73 4.63 20.46 11.23 28.77 19.28 3.35 3.2 6.43 6.65 9.28 10.33a88.64 88.64 0 0 1 6.64 9.72c8.78 14.58 13.82 31.72 13.82 49.97 0 3.26-.16 6.41-.49 9.61-.11 1.41-.28 2.77-.49 4.12v.11c-.22 1.43-.49 2.91-.76 4.36-.28 1.41-.54 2.81-.86 4.21-.05.16-.11.33-.17.49-.3 1.42-.68 2.82-1.07 4.23-.35 1.33-.79 2.7-1.28 3.99a42.96 42.96 0 0 1-1.51 4.16c-.49 1.4-1.07 2.82-1.72 4.16-1.78 4.11-3.9 8.06-6.28 11.83a97.889 97.889 0 0 1-10.47 13.95c30.88-33.2 51.41-76.07 56.52-123.51.86-7.78 1.3-15.67 1.3-23.61 0-5.07-.22-10.09-.55-15.13-3.89-56.89-29.79-107.77-69.32-144.08-10.9-10.09-22.81-19.07-35.62-26.69l-24.2-12.37-122.63-62.93a30.15 30.15 0 0 0-11.93-2.44c-15.88 0-28.99 12.11-30.55 27.56z"
|
||||
fill="url(#a)"/>
|
||||
<path d="M31.62.1C14.17.41.16 14.69.16 32.15v559.06c.07 3.9.29 7.75.57 11.66.25 2.06.52 4.2.9 6.28 7.97 44.87 47.01 78.92 94.15 78.92 16.53 0 32.03-4.21 45.59-11.53.08-.06.22-.14.29-.14l4.88-2.95 19.78-11.64 25.16-14.93.06-496.73c0-33.01-16.52-62.11-41.81-79.4-.6-.36-1.18-.74-1.71-1.17L50.12 5.56C45.16 2.28 39.18.22 32.77.1z"
|
||||
fill="#7f7f7f"/>
|
||||
<path d="M31.62.1C14.17.41.16 14.69.16 32.15v559.06c.07 3.9.29 7.75.57 11.66.25 2.06.52 4.2.9 6.28 7.97 44.87 47.01 78.92 94.15 78.92 16.53 0 32.03-4.21 45.59-11.53.08-.06.22-.14.29-.14l4.88-2.95 19.78-11.64 25.16-14.93.06-496.73c0-33.01-16.52-62.11-41.81-79.4-.6-.36-1.18-.74-1.71-1.17L50.12 5.56C45.16 2.28 39.18.22 32.77.1z"
|
||||
fill="url(#b)"/>
|
||||
<path d="M419.81 510.84L194.72 644.26l-3.24 1.95v.71l-25.16 14.9-19.77 11.67-4.85 2.93-.33.16c-13.53 7.35-29.04 11.51-45.56 11.51-47.13 0-86.22-34.03-94.16-78.92 3.77 32.84 14.96 63.41 31.84 90.04 34.76 54.87 93.54 93.04 161.54 99.67h41.58c36.78-3.84 67.49-18.57 99.77-38.46l49.64-30.36c22.36-14.33 83.05-49.58 100.93-69.36 3.89-4.33 7.4-8.97 10.47-13.94 2.38-3.78 4.5-7.73 6.28-11.84.6-1.4 1.17-2.76 1.72-4.15.52-1.38 1.01-2.77 1.51-4.18.93-2.7 1.67-5.41 2.38-8.2.36-1.59.69-3.16 1.02-4.72 1.08-5.89 1.67-11.94 1.67-18.21 0-18.25-5.04-35.39-13.77-49.95-2-3.4-4.2-6.65-6.64-9.72-2.85-3.7-5.93-7.13-9.28-10.33-8.31-8.05-18.01-14.65-28.77-19.29-4.64-2.05-9.48-3.74-14.46-4.97-.06 0-.16-.06-.22-.06l-1.72-.58z"
|
||||
fill="#7f7f7f"/>
|
||||
<path d="M419.81 510.84L194.72 644.26l-3.24 1.95v.71l-25.16 14.9-19.77 11.67-4.85 2.93-.33.16c-13.53 7.35-29.04 11.51-45.56 11.51-47.13 0-86.22-34.03-94.16-78.92 3.77 32.84 14.96 63.41 31.84 90.04 34.76 54.87 93.54 93.04 161.54 99.67h41.58c36.78-3.84 67.49-18.57 99.77-38.46l49.64-30.36c22.36-14.33 83.05-49.58 100.93-69.36 3.89-4.33 7.4-8.97 10.47-13.94 2.38-3.78 4.5-7.73 6.28-11.84.6-1.4 1.17-2.76 1.72-4.15.52-1.38 1.01-2.77 1.51-4.18.93-2.7 1.67-5.41 2.38-8.2.36-1.59.69-3.16 1.02-4.72 1.08-5.89 1.67-11.94 1.67-18.21 0-18.25-5.04-35.39-13.77-49.95-2-3.4-4.2-6.65-6.64-9.72-2.85-3.7-5.93-7.13-9.28-10.33-8.31-8.05-18.01-14.65-28.77-19.29-4.64-2.05-9.48-3.74-14.46-4.97-.06 0-.16-.06-.22-.06l-1.72-.58z"
|
||||
fill="url(#c)"/>
|
||||
<path d="M512 595.46c0 6.27-.59 12.33-1.68 18.22-.32 1.56-.65 3.12-1.02 4.7-.7 2.8-1.44 5.51-2.37 8.22-.49 1.4-.99 2.8-1.51 4.16-.54 1.4-1.12 2.76-1.73 4.16a87.873 87.873 0 0 1-6.26 11.83 96.567 96.567 0 0 1-10.48 13.94c-17.88 19.79-78.57 55.04-100.93 69.37l-49.64 30.36c-36.39 22.42-70.77 38.29-114.13 39.38-2.05.06-4.06.11-6.05.11-2.8 0-5.56-.05-8.33-.16-73.42-2.8-137.45-42.25-174.38-100.54a213.368 213.368 0 0 1-31.84-90.04c7.94 44.89 47.03 78.92 94.16 78.92 16.52 0 32.03-4.17 45.56-11.51l.33-.17 4.85-2.92 19.77-11.67 25.16-14.9v-.71l3.24-1.95 225.09-133.43 17.33-10.27 1.72.58c.05 0 .16.06.22.06 4.98 1.23 9.83 2.92 14.46 4.97 10.76 4.64 20.45 11.24 28.77 19.29a92.13 92.13 0 0 1 9.28 10.33c2.44 3.07 4.64 6.32 6.64 9.72 8.73 14.56 13.77 31.7 13.77 49.95z"
|
||||
fill="#7f7f7f" opacity=".15"/>
|
||||
<path d="M512 595.46c0 6.27-.59 12.33-1.68 18.22-.32 1.56-.65 3.12-1.02 4.7-.7 2.8-1.44 5.51-2.37 8.22-.49 1.4-.99 2.8-1.51 4.16-.54 1.4-1.12 2.76-1.73 4.16a87.873 87.873 0 0 1-6.26 11.83 96.567 96.567 0 0 1-10.48 13.94c-17.88 19.79-78.57 55.04-100.93 69.37l-49.64 30.36c-36.39 22.42-70.77 38.29-114.13 39.38-2.05.06-4.06.11-6.05.11-2.8 0-5.56-.05-8.33-.16-73.42-2.8-137.45-42.25-174.38-100.54a213.368 213.368 0 0 1-31.84-90.04c7.94 44.89 47.03 78.92 94.16 78.92 16.52 0 32.03-4.17 45.56-11.51l.33-.17 4.85-2.92 19.77-11.67 25.16-14.9v-.71l3.24-1.95 225.09-133.43 17.33-10.27 1.72.58c.05 0 .16.06.22.06 4.98 1.23 9.83 2.92 14.46 4.97 10.76 4.64 20.45 11.24 28.77 19.29a92.13 92.13 0 0 1 9.28 10.33c2.44 3.07 4.64 6.32 6.64 9.72 8.73 14.56 13.77 31.7 13.77 49.95z"
|
||||
fill="url(#d)" opacity=".15"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 6.9 KiB |
|
|
@ -1,23 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.bing.tools.bing_web_search import BingSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class BingProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
BingSearchTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).validate_credentials(
|
||||
credentials=credentials,
|
||||
tool_parameters={
|
||||
"query": "test",
|
||||
"result_type": "link",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
identity:
|
||||
author: Dify
|
||||
name: bing
|
||||
label:
|
||||
en_US: Bing
|
||||
zh_Hans: Bing
|
||||
pt_BR: Bing
|
||||
description:
|
||||
en_US: Bing Search
|
||||
zh_Hans: Bing 搜索
|
||||
pt_BR: Bing Search
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
subscription_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Bing subscription key
|
||||
zh_Hans: Bing subscription key
|
||||
pt_BR: Bing subscription key
|
||||
placeholder:
|
||||
en_US: Please input your Bing subscription key
|
||||
zh_Hans: 请输入你的 Bing subscription key
|
||||
pt_BR: Please input your Bing subscription key
|
||||
help:
|
||||
en_US: Get your Bing subscription key from Bing
|
||||
zh_Hans: 从 Bing 获取您的 Bing subscription key
|
||||
pt_BR: Get your Bing subscription key from Bing
|
||||
url: https://www.microsoft.com/cognitive-services/en-us/bing-web-search-api
|
||||
server_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: Bing endpoint
|
||||
zh_Hans: Bing endpoint
|
||||
pt_BR: Bing endpoint
|
||||
placeholder:
|
||||
en_US: Please input your Bing endpoint
|
||||
zh_Hans: 请输入你的 Bing 端点
|
||||
pt_BR: Please input your Bing endpoint
|
||||
help:
|
||||
en_US: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
|
||||
zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search"
|
||||
pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
|
||||
default: https://api.bing.microsoft.com/v7.0/search
|
||||
allow_entities:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Entities Search
|
||||
zh_Hans: 支持实体搜索
|
||||
pt_BR: Allow Entities Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow entity search
|
||||
zh_Hans: 您的订阅计划是否支持实体搜索
|
||||
pt_BR: Does your subscription plan allow entity search
|
||||
default: true
|
||||
allow_web_pages:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Web Pages Search
|
||||
zh_Hans: 支持网页搜索
|
||||
pt_BR: Allow Web Pages Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow web pages search
|
||||
zh_Hans: 您的订阅计划是否支持网页搜索
|
||||
pt_BR: Does your subscription plan allow web pages search
|
||||
default: true
|
||||
allow_computation:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Computation Search
|
||||
zh_Hans: 支持计算搜索
|
||||
pt_BR: Allow Computation Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow computation search
|
||||
zh_Hans: 您的订阅计划是否支持计算搜索
|
||||
pt_BR: Does your subscription plan allow computation search
|
||||
default: false
|
||||
allow_news:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow News Search
|
||||
zh_Hans: 支持新闻搜索
|
||||
pt_BR: Allow News Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow news search
|
||||
zh_Hans: 您的订阅计划是否支持新闻搜索
|
||||
pt_BR: Does your subscription plan allow news search
|
||||
default: false
|
||||
allow_related_searches:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Related Searches
|
||||
zh_Hans: 支持相关搜索
|
||||
pt_BR: Allow Related Searches
|
||||
help:
|
||||
en_US: Does your subscription plan allow related searches
|
||||
zh_Hans: 您的订阅计划是否支持相关搜索
|
||||
pt_BR: Does your subscription plan allow related searches
|
||||
default: false
|
||||
|
|
@ -1,202 +0,0 @@
|
|||
from typing import Any, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from requests import get
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BingSearchTool(BuiltinTool):
|
||||
url: str = "https://api.bing.microsoft.com/v7.0/search"
|
||||
|
||||
def _invoke_bing(
|
||||
self,
|
||||
user_id: str,
|
||||
server_url: str,
|
||||
subscription_key: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
result_type: str,
|
||||
market: str,
|
||||
lang: str,
|
||||
filters: list[str],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke bing search
|
||||
"""
|
||||
market_code = f"{lang}-{market}"
|
||||
accept_language = f"{lang},{market_code};q=0.9"
|
||||
headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}
|
||||
|
||||
query = quote(query)
|
||||
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||
response = get(server_url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error {response.status_code}: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
search_results = response["webPages"]["value"][:limit] if "webPages" in response else []
|
||||
related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else []
|
||||
entities = response["entities"]["value"] if "entities" in response else []
|
||||
news = response["news"]["value"] if "news" in response else []
|
||||
computation = response["computation"]["value"] if "computation" in response else None
|
||||
|
||||
if result_type == "link":
|
||||
results = []
|
||||
if search_results:
|
||||
for result in search_results:
|
||||
url = f': {result["url"]}' if "url" in result else ""
|
||||
results.append(self.create_text_message(text=f'{result["name"]}{url}'))
|
||||
|
||||
if entities:
|
||||
for entity in entities:
|
||||
url = f': {entity["url"]}' if "url" in entity else ""
|
||||
results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}'))
|
||||
|
||||
if news:
|
||||
for news_item in news:
|
||||
url = f': {news_item["url"]}' if "url" in news_item else ""
|
||||
results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}'))
|
||||
|
||||
if related_searches:
|
||||
for related in related_searches:
|
||||
url = f': {related["displayText"]}' if "displayText" in related else ""
|
||||
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
|
||||
|
||||
return results
|
||||
else:
|
||||
# construct text
|
||||
text = ""
|
||||
if search_results:
|
||||
for i, result in enumerate(search_results):
|
||||
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||
|
||||
if computation and "expression" in computation and "value" in computation:
|
||||
text += "\nComputation:\n"
|
||||
text += f'{computation["expression"]} = {computation["value"]}\n'
|
||||
|
||||
if entities:
|
||||
text += "\nEntities:\n"
|
||||
for entity in entities:
|
||||
url = f'- {entity["url"]}' if "url" in entity else ""
|
||||
text += f'{entity.get("name", "")}{url}\n'
|
||||
|
||||
if news:
|
||||
text += "\nNews:\n"
|
||||
for news_item in news:
|
||||
url = f'- {news_item["url"]}' if "url" in news_item else ""
|
||||
text += f'{news_item.get("name", "")}{url}\n'
|
||||
|
||||
if related_searches:
|
||||
text += "\n\nRelated Searches:\n"
|
||||
for related in related_searches:
|
||||
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
|
||||
text += f'{related.get("displayText", "")}{url}\n'
|
||||
|
||||
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
|
||||
key = credentials.get("subscription_key")
|
||||
if not key:
|
||||
raise Exception("subscription_key is required")
|
||||
|
||||
server_url = credentials.get("server_url")
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
raise Exception("query is required")
|
||||
|
||||
limit = min(tool_parameters.get("limit", 5), 10)
|
||||
result_type = tool_parameters.get("result_type", "text") or "text"
|
||||
|
||||
market = tool_parameters.get("market", "US")
|
||||
lang = tool_parameters.get("language", "en")
|
||||
filter = []
|
||||
|
||||
if credentials.get("allow_entities", False):
|
||||
filter.append("Entities")
|
||||
|
||||
if credentials.get("allow_computation", False):
|
||||
filter.append("Computation")
|
||||
|
||||
if credentials.get("allow_news", False):
|
||||
filter.append("News")
|
||||
|
||||
if credentials.get("allow_related_searches", False):
|
||||
filter.append("RelatedSearches")
|
||||
|
||||
if credentials.get("allow_web_pages", False):
|
||||
filter.append("WebPages")
|
||||
|
||||
if not filter:
|
||||
raise Exception("At least one filter is required")
|
||||
|
||||
self._invoke_bing(
|
||||
user_id="test",
|
||||
server_url=server_url,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter,
|
||||
)
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
|
||||
key = self.runtime.credentials.get("subscription_key", None)
|
||||
if not key:
|
||||
raise Exception("subscription_key is required")
|
||||
|
||||
server_url = self.runtime.credentials.get("server_url", None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
raise Exception("query is required")
|
||||
|
||||
limit = min(tool_parameters.get("limit", 5), 10)
|
||||
result_type = tool_parameters.get("result_type", "text") or "text"
|
||||
|
||||
market = tool_parameters.get("market", "US")
|
||||
lang = tool_parameters.get("language", "en")
|
||||
filter = []
|
||||
|
||||
if tool_parameters.get("enable_computation", False):
|
||||
filter.append("Computation")
|
||||
if tool_parameters.get("enable_entities", False):
|
||||
filter.append("Entities")
|
||||
if tool_parameters.get("enable_news", False):
|
||||
filter.append("News")
|
||||
if tool_parameters.get("enable_related_search", False):
|
||||
filter.append("RelatedSearches")
|
||||
if tool_parameters.get("enable_webpages", False):
|
||||
filter.append("WebPages")
|
||||
|
||||
if not filter:
|
||||
raise Exception("At least one filter is required")
|
||||
|
||||
return self._invoke_bing(
|
||||
user_id=user_id,
|
||||
server_url=server_url,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter,
|
||||
)
|
||||
|
|
@ -1,584 +0,0 @@
|
|||
identity:
|
||||
name: bing_web_search
|
||||
author: Dify
|
||||
label:
|
||||
en_US: BingWebSearch
|
||||
zh_Hans: 必应网页搜索
|
||||
pt_BR: BingWebSearch
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for performing a Bing SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
zh_Hans: 一个用于执行 Bing SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
|
||||
pt_BR: A tool for performing a Bing SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
llm: A tool for performing a Bing SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: Query string
|
||||
human_description:
|
||||
en_US: used for searching
|
||||
zh_Hans: 用于搜索网页内容
|
||||
pt_BR: used for searching
|
||||
llm_description: key words for searching
|
||||
- name: enable_computation
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable computation
|
||||
zh_Hans: 启用计算
|
||||
pt_BR: Enable computation
|
||||
human_description:
|
||||
en_US: enable computation
|
||||
zh_Hans: 启用计算
|
||||
pt_BR: enable computation
|
||||
default: false
|
||||
- name: enable_entities
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable entities
|
||||
zh_Hans: 启用实体搜索
|
||||
pt_BR: Enable entities
|
||||
human_description:
|
||||
en_US: enable entities
|
||||
zh_Hans: 启用实体搜索
|
||||
pt_BR: enable entities
|
||||
default: true
|
||||
- name: enable_news
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable news
|
||||
zh_Hans: 启用新闻搜索
|
||||
pt_BR: Enable news
|
||||
human_description:
|
||||
en_US: enable news
|
||||
zh_Hans: 启用新闻搜索
|
||||
pt_BR: enable news
|
||||
default: false
|
||||
- name: enable_related_search
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable related search
|
||||
zh_Hans: 启用相关搜索
|
||||
pt_BR: Enable related search
|
||||
human_description:
|
||||
en_US: enable related search
|
||||
zh_Hans: 启用相关搜索
|
||||
pt_BR: enable related search
|
||||
default: false
|
||||
- name: enable_webpages
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable webpages search
|
||||
zh_Hans: 启用网页搜索
|
||||
pt_BR: Enable webpages search
|
||||
human_description:
|
||||
en_US: enable webpages search
|
||||
zh_Hans: 启用网页搜索
|
||||
pt_BR: enable webpages search
|
||||
default: true
|
||||
- name: limit
|
||||
type: number
|
||||
required: true
|
||||
form: form
|
||||
label:
|
||||
en_US: Limit for results length
|
||||
zh_Hans: 返回长度限制
|
||||
pt_BR: Limit for results length
|
||||
human_description:
|
||||
en_US: limit the number of results
|
||||
zh_Hans: 限制返回结果的数量
|
||||
pt_BR: limit the number of results
|
||||
min: 1
|
||||
max: 10
|
||||
default: 5
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: result type
|
||||
zh_Hans: 结果类型
|
||||
pt_BR: result type
|
||||
human_description:
|
||||
en_US: return a list of links or texts
|
||||
zh_Hans: 返回一个连接列表还是纯文本内容
|
||||
pt_BR: return a list of links or texts
|
||||
default: text
|
||||
options:
|
||||
- value: link
|
||||
label:
|
||||
en_US: Link
|
||||
zh_Hans: 链接
|
||||
pt_BR: Link
|
||||
- value: text
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
pt_BR: Text
|
||||
form: form
|
||||
- name: market
|
||||
type: select
|
||||
label:
|
||||
en_US: Market
|
||||
zh_Hans: 市场
|
||||
pt_BR: Market
|
||||
human_description:
|
||||
en_US: market takes responsibility for the region
|
||||
zh_Hans: 市场决定了搜索结果的地区
|
||||
pt_BR: market takes responsibility for the region
|
||||
required: false
|
||||
form: form
|
||||
default: US
|
||||
options:
|
||||
- value: AR
|
||||
label:
|
||||
en_US: Argentina
|
||||
zh_Hans: 阿根廷
|
||||
pt_BR: Argentina
|
||||
- value: AU
|
||||
label:
|
||||
en_US: Australia
|
||||
zh_Hans: 澳大利亚
|
||||
pt_BR: Australia
|
||||
- value: AT
|
||||
label:
|
||||
en_US: Austria
|
||||
zh_Hans: 奥地利
|
||||
pt_BR: Austria
|
||||
- value: BE
|
||||
label:
|
||||
en_US: Belgium
|
||||
zh_Hans: 比利时
|
||||
pt_BR: Belgium
|
||||
- value: BR
|
||||
label:
|
||||
en_US: Brazil
|
||||
zh_Hans: 巴西
|
||||
pt_BR: Brazil
|
||||
- value: CA
|
||||
label:
|
||||
en_US: Canada
|
||||
zh_Hans: 加拿大
|
||||
pt_BR: Canada
|
||||
- value: CL
|
||||
label:
|
||||
en_US: Chile
|
||||
zh_Hans: 智利
|
||||
pt_BR: Chile
|
||||
- value: CO
|
||||
label:
|
||||
en_US: Colombia
|
||||
zh_Hans: 哥伦比亚
|
||||
pt_BR: Colombia
|
||||
- value: CN
|
||||
label:
|
||||
en_US: China
|
||||
zh_Hans: 中国
|
||||
pt_BR: China
|
||||
- value: CZ
|
||||
label:
|
||||
en_US: Czech Republic
|
||||
zh_Hans: 捷克共和国
|
||||
pt_BR: Czech Republic
|
||||
- value: DK
|
||||
label:
|
||||
en_US: Denmark
|
||||
zh_Hans: 丹麦
|
||||
pt_BR: Denmark
|
||||
- value: FI
|
||||
label:
|
||||
en_US: Finland
|
||||
zh_Hans: 芬兰
|
||||
pt_BR: Finland
|
||||
- value: FR
|
||||
label:
|
||||
en_US: France
|
||||
zh_Hans: 法国
|
||||
pt_BR: France
|
||||
- value: DE
|
||||
label:
|
||||
en_US: Germany
|
||||
zh_Hans: 德国
|
||||
pt_BR: Germany
|
||||
- value: HK
|
||||
label:
|
||||
en_US: Hong Kong
|
||||
zh_Hans: 香港
|
||||
pt_BR: Hong Kong
|
||||
- value: IN
|
||||
label:
|
||||
en_US: India
|
||||
zh_Hans: 印度
|
||||
pt_BR: India
|
||||
- value: ID
|
||||
label:
|
||||
en_US: Indonesia
|
||||
zh_Hans: 印度尼西亚
|
||||
pt_BR: Indonesia
|
||||
- value: IT
|
||||
label:
|
||||
en_US: Italy
|
||||
zh_Hans: 意大利
|
||||
pt_BR: Italy
|
||||
- value: JP
|
||||
label:
|
||||
en_US: Japan
|
||||
zh_Hans: 日本
|
||||
pt_BR: Japan
|
||||
- value: KR
|
||||
label:
|
||||
en_US: Korea
|
||||
zh_Hans: 韩国
|
||||
pt_BR: Korea
|
||||
- value: MY
|
||||
label:
|
||||
en_US: Malaysia
|
||||
zh_Hans: 马来西亚
|
||||
pt_BR: Malaysia
|
||||
- value: MX
|
||||
label:
|
||||
en_US: Mexico
|
||||
zh_Hans: 墨西哥
|
||||
pt_BR: Mexico
|
||||
- value: NL
|
||||
label:
|
||||
en_US: Netherlands
|
||||
zh_Hans: 荷兰
|
||||
pt_BR: Netherlands
|
||||
- value: NZ
|
||||
label:
|
||||
en_US: New Zealand
|
||||
zh_Hans: 新西兰
|
||||
pt_BR: New Zealand
|
||||
- value: 'NO'
|
||||
label:
|
||||
en_US: Norway
|
||||
zh_Hans: 挪威
|
||||
pt_BR: Norway
|
||||
- value: PH
|
||||
label:
|
||||
en_US: Philippines
|
||||
zh_Hans: 菲律宾
|
||||
pt_BR: Philippines
|
||||
- value: PL
|
||||
label:
|
||||
en_US: Poland
|
||||
zh_Hans: 波兰
|
||||
pt_BR: Poland
|
||||
- value: PT
|
||||
label:
|
||||
en_US: Portugal
|
||||
zh_Hans: 葡萄牙
|
||||
pt_BR: Portugal
|
||||
- value: RU
|
||||
label:
|
||||
en_US: Russia
|
||||
zh_Hans: 俄罗斯
|
||||
pt_BR: Russia
|
||||
- value: SA
|
||||
label:
|
||||
en_US: Saudi Arabia
|
||||
zh_Hans: 沙特阿拉伯
|
||||
pt_BR: Saudi Arabia
|
||||
- value: SG
|
||||
label:
|
||||
en_US: Singapore
|
||||
zh_Hans: 新加坡
|
||||
pt_BR: Singapore
|
||||
- value: ZA
|
||||
label:
|
||||
en_US: South Africa
|
||||
zh_Hans: 南非
|
||||
pt_BR: South Africa
|
||||
- value: ES
|
||||
label:
|
||||
en_US: Spain
|
||||
zh_Hans: 西班牙
|
||||
pt_BR: Spain
|
||||
- value: SE
|
||||
label:
|
||||
en_US: Sweden
|
||||
zh_Hans: 瑞典
|
||||
pt_BR: Sweden
|
||||
- value: CH
|
||||
label:
|
||||
en_US: Switzerland
|
||||
zh_Hans: 瑞士
|
||||
pt_BR: Switzerland
|
||||
- value: TW
|
||||
label:
|
||||
en_US: Taiwan
|
||||
zh_Hans: 台湾
|
||||
pt_BR: Taiwan
|
||||
- value: TH
|
||||
label:
|
||||
en_US: Thailand
|
||||
zh_Hans: 泰国
|
||||
pt_BR: Thailand
|
||||
- value: TR
|
||||
label:
|
||||
en_US: Turkey
|
||||
zh_Hans: 土耳其
|
||||
pt_BR: Turkey
|
||||
- value: GB
|
||||
label:
|
||||
en_US: United Kingdom
|
||||
zh_Hans: 英国
|
||||
pt_BR: United Kingdom
|
||||
- value: US
|
||||
label:
|
||||
en_US: United States
|
||||
zh_Hans: 美国
|
||||
pt_BR: United States
|
||||
- name: language
|
||||
type: select
|
||||
label:
|
||||
en_US: Language
|
||||
zh_Hans: 语言
|
||||
pt_BR: Language
|
||||
human_description:
|
||||
en_US: language takes responsibility for the language of the search result
|
||||
zh_Hans: 语言决定了搜索结果的语言
|
||||
pt_BR: language takes responsibility for the language of the search result
|
||||
required: false
|
||||
default: en
|
||||
form: form
|
||||
options:
|
||||
- value: ar
|
||||
label:
|
||||
en_US: Arabic
|
||||
zh_Hans: 阿拉伯语
|
||||
pt_BR: Arabic
|
||||
- value: bg
|
||||
label:
|
||||
en_US: Bulgarian
|
||||
zh_Hans: 保加利亚语
|
||||
pt_BR: Bulgarian
|
||||
- value: ca
|
||||
label:
|
||||
en_US: Catalan
|
||||
zh_Hans: 加泰罗尼亚语
|
||||
pt_BR: Catalan
|
||||
- value: zh-hans
|
||||
label:
|
||||
en_US: Chinese (Simplified)
|
||||
zh_Hans: 中文(简体)
|
||||
pt_BR: Chinese (Simplified)
|
||||
- value: zh-hant
|
||||
label:
|
||||
en_US: Chinese (Traditional)
|
||||
zh_Hans: 中文(繁体)
|
||||
pt_BR: Chinese (Traditional)
|
||||
- value: cs
|
||||
label:
|
||||
en_US: Czech
|
||||
zh_Hans: 捷克语
|
||||
pt_BR: Czech
|
||||
- value: da
|
||||
label:
|
||||
en_US: Danish
|
||||
zh_Hans: 丹麦语
|
||||
pt_BR: Danish
|
||||
- value: nl
|
||||
label:
|
||||
en_US: Dutch
|
||||
zh_Hans: 荷兰语
|
||||
pt_BR: Dutch
|
||||
- value: en
|
||||
label:
|
||||
en_US: English
|
||||
zh_Hans: 英语
|
||||
pt_BR: English
|
||||
- value: et
|
||||
label:
|
||||
en_US: Estonian
|
||||
zh_Hans: 爱沙尼亚语
|
||||
pt_BR: Estonian
|
||||
- value: fi
|
||||
label:
|
||||
en_US: Finnish
|
||||
zh_Hans: 芬兰语
|
||||
pt_BR: Finnish
|
||||
- value: fr
|
||||
label:
|
||||
en_US: French
|
||||
zh_Hans: 法语
|
||||
pt_BR: French
|
||||
- value: de
|
||||
label:
|
||||
en_US: German
|
||||
zh_Hans: 德语
|
||||
pt_BR: German
|
||||
- value: el
|
||||
label:
|
||||
en_US: Greek
|
||||
zh_Hans: 希腊语
|
||||
pt_BR: Greek
|
||||
- value: he
|
||||
label:
|
||||
en_US: Hebrew
|
||||
zh_Hans: 希伯来语
|
||||
pt_BR: Hebrew
|
||||
- value: hi
|
||||
label:
|
||||
en_US: Hindi
|
||||
zh_Hans: 印地语
|
||||
pt_BR: Hindi
|
||||
- value: hu
|
||||
label:
|
||||
en_US: Hungarian
|
||||
zh_Hans: 匈牙利语
|
||||
pt_BR: Hungarian
|
||||
- value: id
|
||||
label:
|
||||
en_US: Indonesian
|
||||
zh_Hans: 印尼语
|
||||
pt_BR: Indonesian
|
||||
- value: it
|
||||
label:
|
||||
en_US: Italian
|
||||
zh_Hans: 意大利语
|
||||
pt_BR: Italian
|
||||
- value: jp
|
||||
label:
|
||||
en_US: Japanese
|
||||
zh_Hans: 日语
|
||||
pt_BR: Japanese
|
||||
- value: kn
|
||||
label:
|
||||
en_US: Kannada
|
||||
zh_Hans: 卡纳达语
|
||||
pt_BR: Kannada
|
||||
- value: ko
|
||||
label:
|
||||
en_US: Korean
|
||||
zh_Hans: 韩语
|
||||
pt_BR: Korean
|
||||
- value: lv
|
||||
label:
|
||||
en_US: Latvian
|
||||
zh_Hans: 拉脱维亚语
|
||||
pt_BR: Latvian
|
||||
- value: lt
|
||||
label:
|
||||
en_US: Lithuanian
|
||||
zh_Hans: 立陶宛语
|
||||
pt_BR: Lithuanian
|
||||
- value: ms
|
||||
label:
|
||||
en_US: Malay
|
||||
zh_Hans: 马来语
|
||||
pt_BR: Malay
|
||||
- value: ml
|
||||
label:
|
||||
en_US: Malayalam
|
||||
zh_Hans: 马拉雅拉姆语
|
||||
pt_BR: Malayalam
|
||||
- value: mr
|
||||
label:
|
||||
en_US: Marathi
|
||||
zh_Hans: 马拉地语
|
||||
pt_BR: Marathi
|
||||
- value: nb
|
||||
label:
|
||||
en_US: Norwegian
|
||||
zh_Hans: 挪威语
|
||||
pt_BR: Norwegian
|
||||
- value: pl
|
||||
label:
|
||||
en_US: Polish
|
||||
zh_Hans: 波兰语
|
||||
pt_BR: Polish
|
||||
- value: pt-br
|
||||
label:
|
||||
en_US: Portuguese (Brazil)
|
||||
zh_Hans: 葡萄牙语(巴西)
|
||||
pt_BR: Portuguese (Brazil)
|
||||
- value: pt-pt
|
||||
label:
|
||||
en_US: Portuguese (Portugal)
|
||||
zh_Hans: 葡萄牙语(葡萄牙)
|
||||
pt_BR: Portuguese (Portugal)
|
||||
- value: pa
|
||||
label:
|
||||
en_US: Punjabi
|
||||
zh_Hans: 旁遮普语
|
||||
pt_BR: Punjabi
|
||||
- value: ro
|
||||
label:
|
||||
en_US: Romanian
|
||||
zh_Hans: 罗马尼亚语
|
||||
pt_BR: Romanian
|
||||
- value: ru
|
||||
label:
|
||||
en_US: Russian
|
||||
zh_Hans: 俄语
|
||||
pt_BR: Russian
|
||||
- value: sr
|
||||
label:
|
||||
en_US: Serbian
|
||||
zh_Hans: 塞尔维亚语
|
||||
pt_BR: Serbian
|
||||
- value: sk
|
||||
label:
|
||||
en_US: Slovak
|
||||
zh_Hans: 斯洛伐克语
|
||||
pt_BR: Slovak
|
||||
- value: sl
|
||||
label:
|
||||
en_US: Slovenian
|
||||
zh_Hans: 斯洛文尼亚语
|
||||
pt_BR: Slovenian
|
||||
- value: es
|
||||
label:
|
||||
en_US: Spanish
|
||||
zh_Hans: 西班牙语
|
||||
pt_BR: Spanish
|
||||
- value: sv
|
||||
label:
|
||||
en_US: Swedish
|
||||
zh_Hans: 瑞典语
|
||||
pt_BR: Swedish
|
||||
- value: ta
|
||||
label:
|
||||
en_US: Tamil
|
||||
zh_Hans: 泰米尔语
|
||||
pt_BR: Tamil
|
||||
- value: te
|
||||
label:
|
||||
en_US: Telugu
|
||||
zh_Hans: 泰卢固语
|
||||
pt_BR: Telugu
|
||||
- value: th
|
||||
label:
|
||||
en_US: Thai
|
||||
zh_Hans: 泰语
|
||||
pt_BR: Thai
|
||||
- value: tr
|
||||
label:
|
||||
en_US: Turkish
|
||||
zh_Hans: 土耳其语
|
||||
pt_BR: Turkish
|
||||
- value: uk
|
||||
label:
|
||||
en_US: Ukrainian
|
||||
zh_Hans: 乌克兰语
|
||||
pt_BR: Ukrainian
|
||||
- value: vi
|
||||
label:
|
||||
en_US: Vietnamese
|
||||
zh_Hans: 越南语
|
||||
pt_BR: Vietnamese
|
||||
|
|
@ -1 +0,0 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" width="48px" height="48px" clip-rule="evenodd" baseProfile="basic"><linearGradient id="yG17B1EwMCiUUe9ON9hI5a" x1="-329.441" x2="-329.276" y1="-136.877" y2="-136.877" gradientTransform="matrix(217.6 0 0 -255.4727 71694.719 -34944.293)" gradientUnits="userSpaceOnUse"><stop offset="0" stop-color="#e68e00"/><stop offset=".437" stop-color="#d75500"/><stop offset=".562" stop-color="#cf3600"/><stop offset=".89" stop-color="#d22900"/><stop offset="1" stop-color="#d42400"/></linearGradient><path fill="url(#yG17B1EwMCiUUe9ON9hI5a)" fill-rule="evenodd" d="M40.635,13.075l0.984-2.418c0,0-1.252-1.343-2.772-2.865 s-4.74-0.627-4.74-0.627L30.439,3H24h-6.439l-3.667,4.165c0,0-3.22-0.895-4.74,0.627s-2.772,2.865-2.772,2.865l0.984,2.418 l-1.252,3.582c0,0,3.682,13.965,4.114,15.671c0.85,3.358,1.431,4.656,3.846,6.358c2.415,1.701,6.797,4.656,7.512,5.104 C22.301,44.237,23.195,45,24,45c0.805,0,1.699-0.763,2.415-1.21c0.715-0.448,5.098-3.403,7.512-5.104 c2.415-1.701,2.996-3,3.846-6.358c0.431-1.705,4.114-15.671,4.114-15.671L40.635,13.075z" clip-rule="evenodd"/><linearGradient id="yG17B1EwMCiUUe9ON9hI5b" x1="19.087" x2="31.755" y1="7.685" y2="32.547" gradientUnits="userSpaceOnUse"><stop offset="0" stop-color="#fff"/><stop offset=".24" stop-color="#f8f8f7"/><stop offset="1" stop-color="#e3e3e1"/></linearGradient><path fill="url(#yG17B1EwMCiUUe9ON9hI5b)" fill-rule="evenodd" d="M33.078,9.807c0,0,4.716,5.709,4.716,6.929 s-0.593,1.542-1.19,2.176c-0.597,0.634-3.202,3.404-3.536,3.76c-0.335,0.356-1.031,0.895-0.621,1.866 c0.41,0.971,1.014,2.206,0.342,3.459c-0.672,1.253-1.824,2.089-2.561,1.951c-0.738-0.138-2.471-1.045-3.108-1.459 c-0.637-0.414-2.657-2.082-2.657-2.72c0-0.638,2.088-1.784,2.473-2.044c0.386-0.26,2.145-1.268,2.181-1.663 c0.036-0.396,0.022-0.511-0.497-1.489c-0.519-0.977-1.454-2.281-1.298-3.149c0.156-0.868,1.663-1.319,2.74-1.726 c1.076-0.407,3.148-1.175,3.406-1.295c0.259-0.12,0.192-0.233-0.592-0.308c-0.784-0.074-3.009-0.37-4.012-0.09 c-1.003,0.28-2.717,0.706-2.855,0.932c-0.139,0.226-0.261,0.233-0.119,1.012c0.142,0.779,0.876,4.517,0.948,5.181 c0.071,0.664,0.211,1.103-0.504,1.267c-0.715,0.164-1.919,0.448-2.332,0.448s-1.617-0.284-2.332-0.448 c-0.715-0.164-0.576-0.603-0.504-1.267s0.805-4.402,0.948-5.181c0.142-0.779,0.02-0.787-0.119-1.012 c-0.139-0.226-1.852-0.652-2.855-0.932c-1.003-0.28-3.228,0.016-4.012,0.09c-0.784,0.074-0.851,0.188-0.592,0.308 c0.259,0.119,2.331,0.888,3.406,1.295c1.076,0.407,2.584,0.858,2.74,1.726c0.156,0.868-0.779,2.172-1.298,3.149 c-0.519,0.977-0.533,1.093-0.497,1.489c0.036,0.395,1.795,1.403,2.181,1.663c0.386,0.26,2.473,1.406,2.473,2.044 c0,0.638-2.02,2.306-2.657,2.72c-0.637,0.414-2.37,1.321-3.108,1.459c-0.738,0.138-1.889-0.698-2.561-1.951 c-0.672-1.253-0.068-2.488,0.342-3.459c0.41-0.971-0.287-1.51-0.621-1.866c-0.334-0.356-2.94-3.126-3.536-3.76 c-0.597-0.634-1.19-0.956-1.19-2.176s4.716-6.929,4.716-6.929s3.98,0.761,4.516,0.761c0.537,0,1.699-0.448,2.772-0.806 C23.285,9.404,24,9.401,24,9.401s0.715,0.003,1.789,0.361c1.073,0.358,2.236,0.806,2.772,0.806 C29.098,10.568,33.078,9.807,33.078,9.807z M29.542,31.643c0.292,0.183,0.114,0.528-0.152,0.716 c-0.266,0.188-3.84,2.959-4.187,3.265c-0.347,0.306-0.857,0.812-1.203,0.812c-0.347,0-0.856-0.506-1.203-0.812 c-0.347-0.306-3.921-3.077-4.187-3.265c-0.266-0.188-0.444-0.533-0.152-0.716c0.292-0.183,1.205-0.645,2.466-1.298 c1.26-0.653,2.831-1.208,3.076-1.208c0.245,0,1.816,0.555,3.076,1.208C28.336,30.999,29.25,31.46,29.542,31.643z" clip-rule="evenodd"/><linearGradient id="yG17B1EwMCiUUe9ON9hI5c" x1="-329.279" x2="-329.074" y1="-140.492" y2="-140.492" gradientTransform="matrix(180.608 0 0 -46.0337 59468.86 -6460.583)" gradientUnits="userSpaceOnUse"><stop offset="0" stop-color="#e68e00"/><stop offset="1" stop-color="#d42400"/></linearGradient><path fill="url(#yG17B1EwMCiUUe9ON9hI5c)" fill-rule="evenodd" d="M34.106,7.165L30.439,3H24h-6.439l-3.667,4.165 c0,0-3.22-0.895-4.74,0.627c0,0,4.293-0.388,5.769,2.015c0,0,3.98,0.761,4.516,0.761c0.537,0,1.699-0.448,2.772-0.806 C23.285,9.404,24,9.401,24,9.401s0.715,0.003,1.789,0.361c1.073,0.358,2.236,0.806,2.772,0.806c0.537,0,4.516-0.761,4.516-0.761 c1.476-2.403,5.769-2.015,5.769-2.015C37.326,6.27,34.106,7.165,34.106,7.165" clip-rule="evenodd"/></svg>
|
||||
|
Before Width: | Height: | Size: 4.1 KiB |
|
|
@ -1,22 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.brave.tools.brave_search import BraveSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class BraveProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
BraveSearchTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"query": "Sachin Tendulkar",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
identity:
|
||||
author: Yash Parmar
|
||||
name: brave
|
||||
label:
|
||||
en_US: Brave
|
||||
zh_Hans: Brave
|
||||
pt_BR: Brave
|
||||
description:
|
||||
en_US: Brave
|
||||
zh_Hans: Brave
|
||||
pt_BR: Brave
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
brave_search_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Brave Search API key
|
||||
zh_Hans: Brave Search API key
|
||||
pt_BR: Brave Search API key
|
||||
placeholder:
|
||||
en_US: Please input your Brave Search API key
|
||||
zh_Hans: 请输入你的 Brave Search API key
|
||||
pt_BR: Please input your Brave Search API key
|
||||
help:
|
||||
en_US: Get your Brave Search API key from Brave
|
||||
zh_Hans: 从 Brave 获取您的 Brave Search API key
|
||||
pt_BR: Get your Brave Search API key from Brave
|
||||
url: https://brave.com/search/api/
|
||||
base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: Brave server's Base URL
|
||||
zh_Hans: Brave服务器的API URL
|
||||
placeholder:
|
||||
en_US: https://api.search.brave.com/res/v1/web/search
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
BRAVE_BASE_URL = "https://api.search.brave.com/res/v1/web/search"
|
||||
|
||||
|
||||
class BraveSearchWrapper(BaseModel):
|
||||
"""Wrapper around the Brave search engine."""
|
||||
|
||||
api_key: str
|
||||
"""The API key to use for the Brave search engine."""
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the search request."""
|
||||
base_url: str = BRAVE_BASE_URL
|
||||
"""The base URL for the Brave search engine."""
|
||||
ensure_ascii: bool = True
|
||||
"""Ensure the JSON output is ASCII encoded."""
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Query the Brave search engine and return the results as a JSON string.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
|
||||
Returns: The results as a JSON string.
|
||||
|
||||
"""
|
||||
web_search_results = self._search_request(query=query)
|
||||
final_results = [
|
||||
{
|
||||
"title": item.get("title"),
|
||||
"link": item.get("url"),
|
||||
"snippet": item.get("description"),
|
||||
}
|
||||
for item in web_search_results
|
||||
]
|
||||
return json.dumps(final_results, ensure_ascii=self.ensure_ascii)
|
||||
|
||||
def _search_request(self, query: str) -> list[dict]:
|
||||
headers = {
|
||||
"X-Subscription-Token": self.api_key,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
req = requests.PreparedRequest()
|
||||
params = {**self.search_kwargs, **{"q": query}}
|
||||
req.prepare_url(self.base_url, params)
|
||||
if req.url is None:
|
||||
raise ValueError("prepared url is None, this should not happen")
|
||||
|
||||
response = requests.get(req.url, headers=headers)
|
||||
if not response.ok:
|
||||
raise Exception(f"HTTP error {response.status_code}")
|
||||
|
||||
return response.json().get("web", {}).get("results", [])
|
||||
|
||||
|
||||
class BraveSearch(BaseModel):
|
||||
"""Tool that queries the BraveSearch."""
|
||||
|
||||
name: str = "brave_search"
|
||||
description: str = (
|
||||
"a search engine. "
|
||||
"useful for when you need to answer questions about current events."
|
||||
" input should be a search query."
|
||||
)
|
||||
search_wrapper: BraveSearchWrapper
|
||||
|
||||
@classmethod
|
||||
def from_api_key(
|
||||
cls, api_key: str, base_url: str, search_kwargs: Optional[dict] = None, ensure_ascii: bool = True, **kwargs: Any
|
||||
) -> "BraveSearch":
|
||||
"""Create a tool from an api key.
|
||||
|
||||
Args:
|
||||
api_key: The api key to use.
|
||||
search_kwargs: Any additional kwargs to pass to the search wrapper.
|
||||
**kwargs: Any additional kwargs to pass to the tool.
|
||||
|
||||
Returns:
|
||||
A tool.
|
||||
"""
|
||||
wrapper = BraveSearchWrapper(
|
||||
api_key=api_key, base_url=base_url, search_kwargs=search_kwargs or {}, ensure_ascii=ensure_ascii
|
||||
)
|
||||
return cls(search_wrapper=wrapper, **kwargs)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
) -> str:
|
||||
"""Use the tool."""
|
||||
return self.search_wrapper.run(query)
|
||||
|
||||
|
||||
class BraveSearchTool(BuiltinTool):
|
||||
"""
|
||||
Tool for performing a search using Brave search engine.
|
||||
"""
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invoke the Brave search tool.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool invocation.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
|
||||
"""
|
||||
query = tool_parameters.get("query", "")
|
||||
count = tool_parameters.get("count", 3)
|
||||
api_key = self.runtime.credentials["brave_search_api_key"]
|
||||
base_url = self.runtime.credentials.get("base_url", BRAVE_BASE_URL)
|
||||
ensure_ascii = tool_parameters.get("ensure_ascii", True)
|
||||
|
||||
if len(base_url) == 0:
|
||||
base_url = BRAVE_BASE_URL
|
||||
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
tool = BraveSearch.from_api_key(
|
||||
api_key=api_key, base_url=base_url, search_kwargs={"count": count}, ensure_ascii=ensure_ascii
|
||||
)
|
||||
|
||||
results = tool._run(query)
|
||||
|
||||
if not results:
|
||||
return self.create_text_message(f"No results found for '{query}' in Tavily")
|
||||
else:
|
||||
return self.create_text_message(text=results)
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
identity:
|
||||
name: brave_search
|
||||
author: Yash Parmar
|
||||
label:
|
||||
en_US: BraveSearch
|
||||
zh_Hans: BraveSearch
|
||||
pt_BR: BraveSearch
|
||||
description:
|
||||
human:
|
||||
en_US: BraveSearch is a privacy-focused search engine that leverages its own index to deliver unbiased, independent, and fast search results. It's designed to respect user privacy by not tracking searches or personal information, making it a secure choice for those concerned about online privacy.
|
||||
zh_Hans: BraveSearch 是一个注重隐私的搜索引擎,它利用自己的索引来提供公正、独立和快速的搜索结果。它旨在通过不跟踪搜索或个人信息来尊重用户隐私,为那些关注在线隐私的用户提供了一个安全的选择。
|
||||
pt_BR: BraveSearch é um mecanismo de busca focado na privacidade que utiliza seu próprio índice para entregar resultados de busca imparciais, independentes e rápidos. Ele é projetado para respeitar a privacidade do usuário, não rastreando buscas ou informações pessoais, tornando-se uma escolha segura para aqueles preocupados com a privacidade online.
|
||||
llm: BraveSearch is a privacy-centric search engine utilizing its unique index to offer unbiased, independent, and swift search results. It aims to protect user privacy by avoiding the tracking of search activities or personal data, presenting a secure option for users mindful of their online privacy.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: Query string
|
||||
human_description:
|
||||
en_US: The text input used for initiating searches on the web, focusing on delivering relevant and accurate results without compromising user privacy.
|
||||
zh_Hans: 用于在网上启动搜索的文本输入,专注于提供相关且准确的结果,同时不妨碍用户隐私。
|
||||
pt_BR: A entrada de texto usada para iniciar pesquisas na web, focada em entregar resultados relevantes e precisos sem comprometer a privacidade do usuário.
|
||||
llm_description: Keywords or phrases entered to perform searches, aimed at providing relevant and precise results while ensuring the privacy of the user is maintained.
|
||||
form: llm
|
||||
- name: count
|
||||
type: number
|
||||
required: false
|
||||
default: 3
|
||||
label:
|
||||
en_US: Result count
|
||||
zh_Hans: 结果数量
|
||||
pt_BR: Contagem de resultados
|
||||
human_description:
|
||||
en_US: The number of search results to return, allowing users to control the breadth of their search output.
|
||||
zh_Hans: 要返回的搜索结果数量,允许用户控制他们搜索输出的广度。
|
||||
pt_BR: O número de resultados de pesquisa a serem retornados, permitindo que os usuários controlem a amplitude de sua saída de pesquisa.
|
||||
llm_description: Specifies the amount of search results to be displayed, offering users the ability to adjust the scope of their search findings.
|
||||
form: llm
|
||||
- name: ensure_ascii
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
en_US: Ensure ASCII
|
||||
zh_Hans: 确保 ASCII
|
||||
pt_BR: Ensure ASCII
|
||||
human_description:
|
||||
en_US: Ensure the JSON output is ASCII encoded
|
||||
zh_Hans: 确保输出的 JSON 是 ASCII 编码
|
||||
pt_BR: Ensure the JSON output is ASCII encoded
|
||||
form: form
|
||||
|
Before Width: | Height: | Size: 1.3 KiB |
|
|
@ -1,77 +0,0 @@
|
|||
import matplotlib.pyplot as plt
|
||||
from fontTools.ttLib import TTFont
|
||||
from matplotlib.font_manager import findSystemFonts
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
# use a business theme
|
||||
plt.style.use("seaborn-v0_8-darkgrid")
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
|
||||
def init_fonts():
|
||||
fonts = findSystemFonts()
|
||||
|
||||
popular_unicode_fonts = [
|
||||
"Arial Unicode MS",
|
||||
"DejaVu Sans",
|
||||
"DejaVu Sans Mono",
|
||||
"DejaVu Serif",
|
||||
"FreeMono",
|
||||
"FreeSans",
|
||||
"FreeSerif",
|
||||
"Liberation Mono",
|
||||
"Liberation Sans",
|
||||
"Liberation Serif",
|
||||
"Noto Mono",
|
||||
"Noto Sans",
|
||||
"Noto Serif",
|
||||
"Open Sans",
|
||||
"Roboto",
|
||||
"Source Code Pro",
|
||||
"Source Sans Pro",
|
||||
"Source Serif Pro",
|
||||
"Ubuntu",
|
||||
"Ubuntu Mono",
|
||||
]
|
||||
|
||||
supported_fonts = []
|
||||
|
||||
for font_path in fonts:
|
||||
try:
|
||||
font = TTFont(font_path)
|
||||
# get family name
|
||||
family_name = font["name"].getName(1, 3, 1).toUnicode()
|
||||
if family_name in popular_unicode_fonts:
|
||||
supported_fonts.append(family_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
plt.rcParams["font.family"] = "sans-serif"
|
||||
# sort by order of popular_unicode_fonts
|
||||
for font in popular_unicode_fonts:
|
||||
if font in supported_fonts:
|
||||
plt.rcParams["font.sans-serif"] = font
|
||||
break
|
||||
|
||||
|
||||
init_fonts()
|
||||
|
||||
|
||||
class ChartProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
LinearChartTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"data": "1,3,5,7,9,2,4,6,8,10",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
identity:
|
||||
author: Dify
|
||||
name: chart
|
||||
label:
|
||||
en_US: ChartGenerator
|
||||
zh_Hans: 图表生成
|
||||
pt_BR: Gerador de gráficos
|
||||
description:
|
||||
en_US: Chart Generator is a tool for generating statistical charts like bar chart, line chart, pie chart, etc.
|
||||
zh_Hans: 图表生成是一个用于生成可视化图表的工具,你可以通过它来生成柱状图、折线图、饼图等各类图表
|
||||
pt_BR: O Gerador de gráficos é uma ferramenta para gerar gráficos estatísticos como gráfico de barras, gráfico de linhas, gráfico de pizza, etc.
|
||||
icon: icon.png
|
||||
tags:
|
||||
- design
|
||||
- productivity
|
||||
- utilities
|
||||
credentials_for_provider:
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
import io
|
||||
from typing import Any, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BarChartTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all(i.isdigit() for i in data):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
axis = tool_parameters.get("x_axis") or None
|
||||
if axis:
|
||||
axis = axis.split(";")
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha="right")
|
||||
ax.bar(axis, data)
|
||||
else:
|
||||
ax.bar(range(len(data)), data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message("the bar chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
identity:
|
||||
name: bar_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Bar Chart
|
||||
zh_Hans: 柱状图
|
||||
pt_BR: Gráfico de barras
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Bar chart
|
||||
zh_Hans: 柱状图
|
||||
pt_BR: Gráfico de barras
|
||||
llm: generate a bar chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
pt_BR: dados
|
||||
human_description:
|
||||
en_US: data for generating chart, each number should be separated by ";"
|
||||
zh_Hans: 用于生成柱状图的数据,每个数字之间用 ";" 分隔
|
||||
pt_BR: dados para gerar gráfico de barras, cada número deve ser separado por ";"
|
||||
llm_description: data for generating bar chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: x_axis
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: X Axis
|
||||
zh_Hans: x 轴
|
||||
pt_BR: Eixo X
|
||||
human_description:
|
||||
en_US: X axis for chart, each text should be separated by ";"
|
||||
zh_Hans: 柱状图的 x 轴,每个文本之间用 ";" 分隔
|
||||
pt_BR: Eixo X para gráfico de barras, cada texto deve ser separado por ";"
|
||||
llm_description: x axis for bar chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
|
||||
form: llm
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
import io
|
||||
from typing import Any, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class LinearChartTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
|
||||
axis = tool_parameters.get("x_axis") or None
|
||||
if axis:
|
||||
axis = axis.split(";")
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all(i.isdigit() for i in data):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha="right")
|
||||
ax.plot(axis, data)
|
||||
else:
|
||||
ax.plot(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message("the linear chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
identity:
|
||||
name: line_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Linear Chart
|
||||
zh_Hans: 线性图表
|
||||
pt_BR: Gráfico linear
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: linear chart
|
||||
zh_Hans: 线性图表
|
||||
pt_BR: Gráfico linear
|
||||
llm: generate a linear chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
pt_BR: dados
|
||||
human_description:
|
||||
en_US: data for generating chart, each number should be separated by ";"
|
||||
zh_Hans: 用于生成线性图表的数据,每个数字之间用 ";" 分隔
|
||||
pt_BR: dados para gerar gráfico linear, cada número deve ser separado por ";"
|
||||
llm_description: data for generating linear chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: x_axis
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: X Axis
|
||||
zh_Hans: x 轴
|
||||
pt_BR: Eixo X
|
||||
human_description:
|
||||
en_US: X axis for chart, each text should be separated by ";"
|
||||
zh_Hans: 线性图表的 x 轴,每个文本之间用 ";" 分隔
|
||||
pt_BR: Eixo X para gráfico linear, cada texto deve ser separado por ";"
|
||||
llm_description: x axis for linear chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
|
||||
form: llm
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
import io
|
||||
from typing import Any, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class PieChartTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
categories = tool_parameters.get("categories") or None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all(i.isdigit() for i in data):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
flg, ax = plt.subplots()
|
||||
|
||||
if categories:
|
||||
categories = categories.split(";")
|
||||
if len(categories) != len(data):
|
||||
categories = None
|
||||
|
||||
if categories:
|
||||
ax.pie(data, labels=categories)
|
||||
else:
|
||||
ax.pie(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message("the pie chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
identity:
|
||||
name: pie_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Pie Chart
|
||||
zh_Hans: 饼图
|
||||
pt_BR: Gráfico de pizza
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Pie chart
|
||||
zh_Hans: 饼图
|
||||
pt_BR: Gráfico de pizza
|
||||
llm: generate a pie chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
pt_BR: dados
|
||||
human_description:
|
||||
en_US: data for generating chart, each number should be separated by ";"
|
||||
zh_Hans: 用于生成饼图的数据,每个数字之间用 ";" 分隔
|
||||
pt_BR: dados para gerar gráfico de pizza, cada número deve ser separado por ";"
|
||||
llm_description: data for generating pie chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: categories
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Categories
|
||||
zh_Hans: 分类
|
||||
pt_BR: Categorias
|
||||
human_description:
|
||||
en_US: Categories for chart, each category should be separated by ";"
|
||||
zh_Hans: 饼图的分类,每个分类之间用 ";" 分隔
|
||||
pt_BR: Categorias para gráfico de pizza, cada categoria deve ser separada por ";"
|
||||
llm_description: categories for pie chart, categories should be a string contains a list of texts like "a;b;c;1;2" in order to match the data, each category should be split by ";"
|
||||
form: llm
|
||||
|
Before Width: | Height: | Size: 22 KiB |
|
|
@ -1,28 +0,0 @@
|
|||
"""Provide the input parameters type for the cogview provider class"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class COGVIEWProvider(BuiltinToolProviderController):
|
||||
"""cogview provider"""
|
||||
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
CogView3Tool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
|
||||
"size": "square",
|
||||
"n": 1,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e)) from e
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
identity:
|
||||
author: Waffle
|
||||
name: cogview
|
||||
label:
|
||||
en_US: CogView
|
||||
zh_Hans: CogView 绘画
|
||||
pt_BR: CogView
|
||||
description:
|
||||
en_US: CogView art
|
||||
zh_Hans: CogView 绘画
|
||||
pt_BR: CogView art
|
||||
icon: icon.png
|
||||
tags:
|
||||
- image
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
zhipuai_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: ZhipuAI API key
|
||||
zh_Hans: ZhipuAI API key
|
||||
pt_BR: ZhipuAI API key
|
||||
help:
|
||||
en_US: Please input your ZhipuAI API key
|
||||
zh_Hans: 请输入你的 ZhipuAI API key
|
||||
pt_BR: Please input your ZhipuAI API key
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI API key
|
||||
zh_Hans: 请输入你的 ZhipuAI API key
|
||||
pt_BR: Please input your ZhipuAI API key
|
||||
zhipuai_organizaion_id:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: ZhipuAI organization ID
|
||||
zh_Hans: ZhipuAI organization ID
|
||||
pt_BR: ZhipuAI organization ID
|
||||
help:
|
||||
en_US: Please input your ZhipuAI organization ID
|
||||
zh_Hans: 请输入你的 ZhipuAI organization ID
|
||||
pt_BR: Please input your ZhipuAI organization ID
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI organization ID
|
||||
zh_Hans: 请输入你的 ZhipuAI organization ID
|
||||
pt_BR: Please input your ZhipuAI organization ID
|
||||
zhipuai_base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: ZhipuAI base URL
|
||||
zh_Hans: ZhipuAI base URL
|
||||
pt_BR: ZhipuAI base URL
|
||||
help:
|
||||
en_US: Please input your ZhipuAI base URL
|
||||
zh_Hans: 请输入你的 ZhipuAI base URL
|
||||
pt_BR: Please input your ZhipuAI base URL
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI base URL
|
||||
zh_Hans: 请输入你的 ZhipuAI base URL
|
||||
pt_BR: Please input your ZhipuAI base URL
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
import random
|
||||
from typing import Any, Union
|
||||
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CogView3Tool(BuiltinTool):
|
||||
"""CogView3 Tool"""
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke CogView3 tool
|
||||
"""
|
||||
client = ZhipuAI(
|
||||
base_url=self.runtime.credentials["zhipuai_base_url"],
|
||||
api_key=self.runtime.credentials["zhipuai_api_key"],
|
||||
)
|
||||
size_mapping = {
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
# prompt
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = size_mapping[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
extra_body = {"seed": seed_id}
|
||||
response = client.images.generations(
|
||||
prompt=prompt,
|
||||
model="cogview-3",
|
||||
size=size,
|
||||
n=n,
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format="b64_json",
|
||||
)
|
||||
result = []
|
||||
for image in response.data:
|
||||
result.append(self.create_image_message(image=image.url))
|
||||
result.append(
|
||||
self.create_json_message(
|
||||
{
|
||||
"url": image.url,
|
||||
}
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
identity:
|
||||
name: cogview3
|
||||
author: Waffle
|
||||
label:
|
||||
en_US: CogView 3
|
||||
zh_Hans: CogView 3 绘画
|
||||
pt_BR: CogView 3
|
||||
description:
|
||||
en_US: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
|
||||
zh_Hans: CogView 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像
|
||||
pt_BR: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
|
||||
description:
|
||||
human:
|
||||
en_US: CogView 3 is a text to image tool
|
||||
zh_Hans: CogView 3 是一个文本到图像的工具
|
||||
pt_BR: CogView 3 is a text to image tool
|
||||
llm: CogView 3 is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of CogView 3
|
||||
zh_Hans: 图像提示词,您可以查看 CogView 3 的官方文档
|
||||
pt_BR: Image prompt, you can check the official documentation of CogView 3
|
||||
llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words as possible as detailed
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image size
|
||||
zh_Hans: 选择图像大小
|
||||
pt_BR: selecting the image size
|
||||
label:
|
||||
en_US: Image size
|
||||
zh_Hans: 图像大小
|
||||
pt_BR: Image size
|
||||
form: form
|
||||
options:
|
||||
- value: square
|
||||
label:
|
||||
en_US: Squre(1024x1024)
|
||||
zh_Hans: 方(1024x1024)
|
||||
pt_BR: Squre(1024x1024)
|
||||
- value: vertical
|
||||
label:
|
||||
en_US: Vertical(1024x1792)
|
||||
zh_Hans: 竖屏(1024x1792)
|
||||
pt_BR: Vertical(1024x1792)
|
||||
- value: horizontal
|
||||
label:
|
||||
en_US: Horizontal(1792x1024)
|
||||
zh_Hans: 横屏(1792x1024)
|
||||
pt_BR: Horizontal(1792x1024)
|
||||
default: square
|
||||
- name: n
|
||||
type: number
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the number of images
|
||||
zh_Hans: 选择图像数量
|
||||
pt_BR: selecting the number of images
|
||||
label:
|
||||
en_US: Number of images
|
||||
zh_Hans: 图像数量
|
||||
pt_BR: Number of images
|
||||
form: form
|
||||
min: 1
|
||||
max: 1
|
||||
default: 1
|
||||
- name: quality
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image quality
|
||||
zh_Hans: 选择图像质量
|
||||
pt_BR: selecting the image quality
|
||||
label:
|
||||
en_US: Image quality
|
||||
zh_Hans: 图像质量
|
||||
pt_BR: Image quality
|
||||
form: form
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
pt_BR: Standard
|
||||
- value: hd
|
||||
label:
|
||||
en_US: HD
|
||||
zh_Hans: 高清
|
||||
pt_BR: HD
|
||||
default: standard
|
||||
- name: style
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image style
|
||||
zh_Hans: 选择图像风格
|
||||
pt_BR: selecting the image style
|
||||
label:
|
||||
en_US: Image style
|
||||
zh_Hans: 图像风格
|
||||
pt_BR: Image style
|
||||
form: form
|
||||
options:
|
||||
- value: vivid
|
||||
label:
|
||||
en_US: Vivid
|
||||
zh_Hans: 生动
|
||||
pt_BR: Vivid
|
||||
- value: natural
|
||||
label:
|
||||
en_US: Natural
|
||||
zh_Hans: 自然
|
||||
pt_BR: Natural
|
||||
default: vivid
|
||||
|
Before Width: | Height: | Size: 209 KiB |
|
|
@ -1,17 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class ComfyUIProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
ComfyuiStableDiffusionTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).validate_models()
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
identity:
|
||||
author: Qun
|
||||
name: comfyui
|
||||
label:
|
||||
en_US: ComfyUI
|
||||
zh_Hans: ComfyUI
|
||||
pt_BR: ComfyUI
|
||||
description:
|
||||
en_US: ComfyUI is a tool for generating images which can be deployed locally.
|
||||
zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。
|
||||
pt_BR: ComfyUI is a tool for generating images which can be deployed locally.
|
||||
icon: icon.png
|
||||
tags:
|
||||
- image
|
||||
credentials_for_provider:
|
||||
base_url:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: ComfyUI服务器的Base URL
|
||||
pt_BR: Base URL
|
||||
placeholder:
|
||||
en_US: Please input your ComfyUI server's Base URL
|
||||
zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL
|
||||
pt_BR: Please input your ComfyUI server's Base URL
|
||||
model:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Model with suffix
|
||||
zh_Hans: 模型, 需要带后缀
|
||||
pt_BR: Model with suffix
|
||||
placeholder:
|
||||
en_US: Please input your model
|
||||
zh_Hans: 请输入你的模型名称
|
||||
pt_BR: Please input your model
|
||||
help:
|
||||
en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
|
||||
zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors
|
||||
pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
|
||||
url: https://docs.dify.ai/tutorials/tool-configuration/comfyui
|
||||
|
|
@ -1,475 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Union
|
||||
|
||||
import websocket
|
||||
from httpx import get, post
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
SD_TXT2IMG_OPTIONS = {}
|
||||
LORA_NODE = {
|
||||
"inputs": {"lora_name": "", "strength_model": 1, "strength_clip": 1, "model": ["11", 0], "clip": ["11", 1]},
|
||||
"class_type": "LoraLoader",
|
||||
"_meta": {"title": "Load LoRA"},
|
||||
}
|
||||
FluxGuidanceNode = {
|
||||
"inputs": {"guidance": 3.5, "conditioning": ["6", 0]},
|
||||
"class_type": "FluxGuidance",
|
||||
"_meta": {"title": "FluxGuidance"},
|
||||
}
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
SD15 = 1
|
||||
SDXL = 2
|
||||
SD3 = 3
|
||||
FLUX = 4
|
||||
|
||||
|
||||
class ComfyuiStableDiffusionTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
# base url
|
||||
base_url = self.runtime.credentials.get("base_url", "")
|
||||
if not base_url:
|
||||
return self.create_text_message("Please input base_url")
|
||||
|
||||
if tool_parameters.get("model"):
|
||||
self.runtime.credentials["model"] = tool_parameters["model"]
|
||||
|
||||
model = self.runtime.credentials.get("model", None)
|
||||
if not model:
|
||||
return self.create_text_message("Please input model")
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
|
||||
# get negative prompt
|
||||
negative_prompt = tool_parameters.get("negative_prompt", "")
|
||||
|
||||
# get size
|
||||
width = tool_parameters.get("width", 1024)
|
||||
height = tool_parameters.get("height", 1024)
|
||||
|
||||
# get steps
|
||||
steps = tool_parameters.get("steps", 1)
|
||||
|
||||
# get sampler_name
|
||||
sampler_name = tool_parameters.get("sampler_name", "euler")
|
||||
|
||||
# scheduler
|
||||
scheduler = tool_parameters.get("scheduler", "normal")
|
||||
|
||||
# get cfg
|
||||
cfg = tool_parameters.get("cfg", 7.0)
|
||||
|
||||
# get model type
|
||||
model_type = tool_parameters.get("model_type", ModelType.SD15.name)
|
||||
|
||||
# get lora
|
||||
# supports up to 3 loras
|
||||
lora_list = []
|
||||
lora_strength_list = []
|
||||
if tool_parameters.get("lora_1"):
|
||||
lora_list.append(tool_parameters["lora_1"])
|
||||
lora_strength_list.append(tool_parameters.get("lora_strength_1", 1))
|
||||
if tool_parameters.get("lora_2"):
|
||||
lora_list.append(tool_parameters["lora_2"])
|
||||
lora_strength_list.append(tool_parameters.get("lora_strength_2", 1))
|
||||
if tool_parameters.get("lora_3"):
|
||||
lora_list.append(tool_parameters["lora_3"])
|
||||
lora_strength_list.append(tool_parameters.get("lora_strength_3", 1))
|
||||
|
||||
return self.text2img(
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
sampler_name=sampler_name,
|
||||
scheduler=scheduler,
|
||||
cfg=cfg,
|
||||
lora_list=lora_list,
|
||||
lora_strength_list=lora_strength_list,
|
||||
)
|
||||
|
||||
def get_checkpoints(self) -> list[str]:
|
||||
"""
|
||||
get checkpoints
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
if not base_url:
|
||||
return []
|
||||
api_url = str(URL(base_url) / "models" / "checkpoints")
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
else:
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
def get_loras(self) -> list[str]:
|
||||
"""
|
||||
get loras
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
if not base_url:
|
||||
return []
|
||||
api_url = str(URL(base_url) / "models" / "loras")
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
else:
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
def get_sample_methods(self) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
get sample method
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
if not base_url:
|
||||
return [], []
|
||||
api_url = str(URL(base_url) / "object_info" / "KSampler")
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
return [], []
|
||||
else:
|
||||
data = response.json()["KSampler"]["input"]["required"]
|
||||
return data["sampler_name"][0], data["scheduler"][0]
|
||||
except Exception as e:
|
||||
return [], []
|
||||
|
||||
def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
validate models
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
if not base_url:
|
||||
raise ToolProviderCredentialValidationError("Please input base_url")
|
||||
model = self.runtime.credentials.get("model", None)
|
||||
if not model:
|
||||
raise ToolProviderCredentialValidationError("Please input model")
|
||||
|
||||
api_url = str(URL(base_url) / "models" / "checkpoints")
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError("Failed to get models")
|
||||
else:
|
||||
models = response.json()
|
||||
if len([d for d in models if d == model]) > 0:
|
||||
return self.create_text_message(json.dumps(models))
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f"model {model} does not exist")
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(f"Failed to get models, {e}")
|
||||
|
||||
def get_history(self, base_url, prompt_id):
|
||||
"""
|
||||
get history
|
||||
"""
|
||||
url = str(URL(base_url) / "history")
|
||||
respond = get(url, params={"prompt_id": prompt_id}, timeout=(2, 10))
|
||||
return respond.json()
|
||||
|
||||
def download_image(self, base_url, filename, subfolder, folder_type):
|
||||
"""
|
||||
download image
|
||||
"""
|
||||
url = str(URL(base_url) / "view")
|
||||
response = get(url, params={"filename": filename, "subfolder": subfolder, "type": folder_type}, timeout=(2, 10))
|
||||
return response.content
|
||||
|
||||
def queue_prompt_image(self, base_url, client_id, prompt):
|
||||
"""
|
||||
send prompt task and rotate
|
||||
"""
|
||||
# initiate task execution
|
||||
url = str(URL(base_url) / "prompt")
|
||||
respond = post(url, data=json.dumps({"client_id": client_id, "prompt": prompt}), timeout=(2, 10))
|
||||
prompt_id = respond.json()["prompt_id"]
|
||||
|
||||
ws = websocket.WebSocket()
|
||||
if "https" in base_url:
|
||||
ws_url = base_url.replace("https", "ws")
|
||||
else:
|
||||
ws_url = base_url.replace("http", "ws")
|
||||
ws.connect(str(URL(f"{ws_url}") / "ws") + f"?clientId={client_id}", timeout=120)
|
||||
|
||||
# websocket rotate execution status
|
||||
output_images = {}
|
||||
while True:
|
||||
out = ws.recv()
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
if message["type"] == "executing":
|
||||
data = message["data"]
|
||||
if data["node"] is None and data["prompt_id"] == prompt_id:
|
||||
break # Execution is done
|
||||
elif message["type"] == "status":
|
||||
data = message["data"]
|
||||
if data["status"]["exec_info"]["queue_remaining"] == 0 and data.get("sid"):
|
||||
break # Execution is done
|
||||
else:
|
||||
continue # previews are binary data
|
||||
|
||||
# download image when execution finished
|
||||
history = self.get_history(base_url, prompt_id)[prompt_id]
|
||||
for o in history["outputs"]:
|
||||
for node_id in history["outputs"]:
|
||||
node_output = history["outputs"][node_id]
|
||||
if "images" in node_output:
|
||||
images_output = []
|
||||
for image in node_output["images"]:
|
||||
image_data = self.download_image(base_url, image["filename"], image["subfolder"], image["type"])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
|
||||
ws.close()
|
||||
|
||||
return output_images
|
||||
|
||||
def text2img(
|
||||
self,
|
||||
base_url: str,
|
||||
model: str,
|
||||
model_type: str,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
width: int,
|
||||
height: int,
|
||||
steps: int,
|
||||
sampler_name: str,
|
||||
scheduler: str,
|
||||
cfg: float,
|
||||
lora_list: list,
|
||||
lora_strength_list: list,
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
generate image
|
||||
"""
|
||||
if not SD_TXT2IMG_OPTIONS:
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
with open(os.path.join(current_dir, "txt2img.json")) as file:
|
||||
SD_TXT2IMG_OPTIONS.update(json.load(file))
|
||||
|
||||
draw_options = deepcopy(SD_TXT2IMG_OPTIONS)
|
||||
draw_options["3"]["inputs"]["steps"] = steps
|
||||
draw_options["3"]["inputs"]["sampler_name"] = sampler_name
|
||||
draw_options["3"]["inputs"]["scheduler"] = scheduler
|
||||
draw_options["3"]["inputs"]["cfg"] = cfg
|
||||
# generate different image when using same prompt next time
|
||||
draw_options["3"]["inputs"]["seed"] = random.randint(0, 100000000)
|
||||
draw_options["4"]["inputs"]["ckpt_name"] = model
|
||||
draw_options["5"]["inputs"]["width"] = width
|
||||
draw_options["5"]["inputs"]["height"] = height
|
||||
draw_options["6"]["inputs"]["text"] = prompt
|
||||
draw_options["7"]["inputs"]["text"] = negative_prompt
|
||||
# if the model is SD3 or FLUX series, the Latent class should be corresponding to SD3 Latent
|
||||
if model_type in {ModelType.SD3.name, ModelType.FLUX.name}:
|
||||
draw_options["5"]["class_type"] = "EmptySD3LatentImage"
|
||||
|
||||
if lora_list:
|
||||
# last Lora node link to KSampler node
|
||||
draw_options["3"]["inputs"]["model"][0] = "10"
|
||||
# last Lora node link to positive and negative Clip node
|
||||
draw_options["6"]["inputs"]["clip"][0] = "10"
|
||||
draw_options["7"]["inputs"]["clip"][0] = "10"
|
||||
# every Lora node link to next Lora node, and Checkpoints node link to first Lora node
|
||||
for i, (lora, strength) in enumerate(zip(lora_list, lora_strength_list), 10):
|
||||
if i - 10 == len(lora_list) - 1:
|
||||
next_node_id = "4"
|
||||
else:
|
||||
next_node_id = str(i + 1)
|
||||
lora_node = deepcopy(LORA_NODE)
|
||||
lora_node["inputs"]["lora_name"] = lora
|
||||
lora_node["inputs"]["strength_model"] = strength
|
||||
lora_node["inputs"]["strength_clip"] = strength
|
||||
lora_node["inputs"]["model"][0] = next_node_id
|
||||
lora_node["inputs"]["clip"][0] = next_node_id
|
||||
draw_options[str(i)] = lora_node
|
||||
|
||||
# FLUX need to add FluxGuidance Node
|
||||
if model_type == ModelType.FLUX.name:
|
||||
last_node_id = str(10 + len(lora_list))
|
||||
draw_options[last_node_id] = deepcopy(FluxGuidanceNode)
|
||||
draw_options[last_node_id]["inputs"]["conditioning"][0] = "6"
|
||||
draw_options["3"]["inputs"]["positive"][0] = last_node_id
|
||||
|
||||
try:
|
||||
client_id = str(uuid.uuid4())
|
||||
result = self.queue_prompt_image(base_url, client_id, prompt=draw_options)
|
||||
|
||||
# get first image
|
||||
image = b""
|
||||
for node in result:
|
||||
for img in result[node]:
|
||||
if img:
|
||||
image = img
|
||||
break
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=image, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to generate image: {str(e)}")
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
parameters = [
|
||||
ToolParameter(
|
||||
name="prompt",
|
||||
label=I18nObject(en_US="Prompt", zh_Hans="Prompt"),
|
||||
human_description=I18nObject(
|
||||
en_US="Image prompt, you can check the official documentation of Stable Diffusion",
|
||||
zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description="Image prompt of Stable Diffusion, you should describe the image "
|
||||
"you want to generate as a list of words as possible as detailed, "
|
||||
"the prompt must be written in English.",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
if self.runtime.credentials:
|
||||
try:
|
||||
models = self.get_checkpoints()
|
||||
if len(models) != 0:
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="model",
|
||||
label=I18nObject(en_US="Model", zh_Hans="Model"),
|
||||
human_description=I18nObject(
|
||||
en_US="Model of Stable Diffusion or FLUX, "
|
||||
"you can check the official documentation of Stable Diffusion or FLUX",
|
||||
zh_Hans="Stable Diffusion 或者 FLUX 的模型,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Model of Stable Diffusion or FLUX, "
|
||||
"you can check the official documentation of Stable Diffusion or FLUX",
|
||||
required=True,
|
||||
default=models[0],
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models
|
||||
],
|
||||
)
|
||||
)
|
||||
loras = self.get_loras()
|
||||
if len(loras) != 0:
|
||||
for n in range(1, 4):
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=f"lora_{n}",
|
||||
label=I18nObject(en_US=f"Lora {n}", zh_Hans=f"Lora {n}"),
|
||||
human_description=I18nObject(
|
||||
en_US="Lora of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
zh_Hans="Stable Diffusion 的 Lora 模型,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Lora of Stable Diffusion, "
|
||||
"you can check the official documentation of "
|
||||
"Stable Diffusion",
|
||||
required=False,
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in loras
|
||||
],
|
||||
)
|
||||
)
|
||||
sample_methods, schedulers = self.get_sample_methods()
|
||||
if len(sample_methods) != 0:
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="sampler_name",
|
||||
label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"),
|
||||
human_description=I18nObject(
|
||||
en_US="Sampling method of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Sampling method of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
required=True,
|
||||
default=sample_methods[0],
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i))
|
||||
for i in sample_methods
|
||||
],
|
||||
)
|
||||
)
|
||||
if len(schedulers) != 0:
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="scheduler",
|
||||
label=I18nObject(en_US="Scheduler", zh_Hans="Scheduler"),
|
||||
human_description=I18nObject(
|
||||
en_US="Scheduler of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
zh_Hans="Stable Diffusion 的Scheduler,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Scheduler of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
required=True,
|
||||
default=schedulers[0],
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in schedulers
|
||||
],
|
||||
)
|
||||
)
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="model_type",
|
||||
label=I18nObject(en_US="Model Type", zh_Hans="Model Type"),
|
||||
human_description=I18nObject(
|
||||
en_US="Model Type of Stable Diffusion or Flux, "
|
||||
"you can check the official documentation of Stable Diffusion or Flux",
|
||||
zh_Hans="Stable Diffusion 或 FLUX 的模型类型,"
|
||||
"您可以查看 Stable Diffusion 或 Flux 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Model Type of Stable Diffusion or Flux, "
|
||||
"you can check the official documentation of Stable Diffusion or Flux",
|
||||
required=True,
|
||||
default=ModelType.SD15.name,
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i))
|
||||
for i in ModelType.__members__
|
||||
],
|
||||
)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
return parameters
|
||||
|
|
@ -1,212 +0,0 @@
|
|||
identity:
|
||||
name: txt2img workflow
|
||||
author: Qun
|
||||
label:
|
||||
en_US: Txt2Img Workflow
|
||||
zh_Hans: Txt2Img Workflow
|
||||
pt_BR: Txt2Img Workflow
|
||||
description:
|
||||
human:
|
||||
en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.
|
||||
zh_Hans: 一个预定义的 ComfyUI 工作流,可以使用一个模型和最多3个loras来生成图像。支持包含文本编码器/clip的SD1.5、SDXL、SD3和FLUX,但不支持需要clip加载器的模型。
|
||||
pt_BR: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.
|
||||
llm: draw the image you want based on your prompt.
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of Stable Diffusion or FLUX
|
||||
zh_Hans: 图像提示词,您可以查看 Stable Diffusion 或者 FLUX 的官方文档
|
||||
pt_BR: Image prompt, you can check the official documentation of Stable Diffusion or FLUX
|
||||
llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||
form: llm
|
||||
- name: model
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
pt_BR: Model Name
|
||||
human_description:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
pt_BR: Model Name
|
||||
form: form
|
||||
- name: model_type
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Model Type
|
||||
zh_Hans: 模型类型
|
||||
pt_BR: Model Type
|
||||
human_description:
|
||||
en_US: Model Type
|
||||
zh_Hans: 模型类型
|
||||
pt_BR: Model Type
|
||||
form: form
|
||||
- name: lora_1
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora 1
|
||||
zh_Hans: Lora 1
|
||||
pt_BR: Lora 1
|
||||
human_description:
|
||||
en_US: Lora 1
|
||||
zh_Hans: Lora 1
|
||||
pt_BR: Lora 1
|
||||
form: form
|
||||
- name: lora_strength_1
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora Strength 1
|
||||
zh_Hans: Lora Strength 1
|
||||
pt_BR: Lora Strength 1
|
||||
human_description:
|
||||
en_US: Lora Strength 1
|
||||
zh_Hans: Lora模型的权重
|
||||
pt_BR: Lora Strength 1
|
||||
form: form
|
||||
- name: steps
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Steps
|
||||
zh_Hans: Steps
|
||||
pt_BR: Steps
|
||||
human_description:
|
||||
en_US: Steps
|
||||
zh_Hans: Steps
|
||||
pt_BR: Steps
|
||||
form: form
|
||||
default: 20
|
||||
- name: width
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Width
|
||||
zh_Hans: Width
|
||||
pt_BR: Width
|
||||
human_description:
|
||||
en_US: Width
|
||||
zh_Hans: Width
|
||||
pt_BR: Width
|
||||
form: form
|
||||
default: 1024
|
||||
- name: height
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Height
|
||||
zh_Hans: Height
|
||||
pt_BR: Height
|
||||
human_description:
|
||||
en_US: Height
|
||||
zh_Hans: Height
|
||||
pt_BR: Height
|
||||
form: form
|
||||
default: 1024
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Negative prompt
|
||||
zh_Hans: Negative prompt
|
||||
pt_BR: Negative prompt
|
||||
human_description:
|
||||
en_US: Negative prompt
|
||||
zh_Hans: Negative prompt
|
||||
pt_BR: Negative prompt
|
||||
form: form
|
||||
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
|
||||
- name: cfg
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: CFG Scale
|
||||
zh_Hans: CFG Scale
|
||||
pt_BR: CFG Scale
|
||||
human_description:
|
||||
en_US: CFG Scale
|
||||
zh_Hans: 提示词相关性(CFG Scale)
|
||||
pt_BR: CFG Scale
|
||||
form: form
|
||||
default: 7.0
|
||||
- name: sampler_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Sampling method
|
||||
zh_Hans: Sampling method
|
||||
pt_BR: Sampling method
|
||||
human_description:
|
||||
en_US: Sampling method
|
||||
zh_Hans: Sampling method
|
||||
pt_BR: Sampling method
|
||||
form: form
|
||||
- name: scheduler
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Scheduler
|
||||
zh_Hans: Scheduler
|
||||
pt_BR: Scheduler
|
||||
human_description:
|
||||
en_US: Scheduler
|
||||
zh_Hans: Scheduler
|
||||
pt_BR: Scheduler
|
||||
form: form
|
||||
- name: lora_2
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora 2
|
||||
zh_Hans: Lora 2
|
||||
pt_BR: Lora 2
|
||||
human_description:
|
||||
en_US: Lora 2
|
||||
zh_Hans: Lora 2
|
||||
pt_BR: Lora 2
|
||||
form: form
|
||||
- name: lora_strength_2
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora Strength 2
|
||||
zh_Hans: Lora Strength 2
|
||||
pt_BR: Lora Strength 2
|
||||
human_description:
|
||||
en_US: Lora Strength 2
|
||||
zh_Hans: Lora模型的权重
|
||||
pt_BR: Lora Strength 2
|
||||
form: form
|
||||
- name: lora_3
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora 3
|
||||
zh_Hans: Lora 3
|
||||
pt_BR: Lora 3
|
||||
human_description:
|
||||
en_US: Lora 3
|
||||
zh_Hans: Lora 3
|
||||
pt_BR: Lora 3
|
||||
form: form
|
||||
- name: lora_strength_3
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora Strength 3
|
||||
zh_Hans: Lora Strength 3
|
||||
pt_BR: Lora Strength 3
|
||||
human_description:
|
||||
en_US: Lora Strength 3
|
||||
zh_Hans: Lora模型的权重
|
||||
pt_BR: Lora Strength 3
|
||||
form: form
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 156680208700286,
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "3dAnimationDiffusion_v10.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "text, watermark",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 19.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 200 130.2" style="enable-background:new 0 0 200 130.2;" xml:space="preserve">
|
||||
<style type="text/css">
|
||||
.st0{fill:#3EB1C8;}
|
||||
.st1{fill:#D8D2C4;}
|
||||
.st2{fill:#4F5858;}
|
||||
.st3{fill:#FFC72C;}
|
||||
.st4{fill:#EF3340;}
|
||||
</style>
|
||||
<g>
|
||||
<polygon class="st0" points="111.8,95.5 111.8,66.8 135.4,59 177.2,73.3 "/>
|
||||
<polygon class="st1" points="153.6,36.8 111.8,51.2 135.4,59 177.2,44.6 "/>
|
||||
<polygon class="st2" points="135.4,59 177.2,44.6 177.2,73.3 "/>
|
||||
<polygon class="st3" points="177.2,0.3 177.2,29 153.6,36.8 111.8,22.5 "/>
|
||||
<polygon class="st4" points="153.6,36.8 111.8,51.2 111.8,22.5 "/>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path class="st2" d="M26.3,104.8c-0.5-3.7-4.1-6.5-8.1-6.5c-7.3,0-10.1,6.2-10.1,12.7c0,6.2,2.8,12.4,10.1,12.4
|
||||
c5,0,7.8-3.4,8.4-8.3h7.9c-0.8,9.2-7.2,15.2-16.3,15.2C6.8,130.2,0,121.7,0,111c0-11,6.8-19.6,18.2-19.6c8.2,0,15,4.8,16,13.3
|
||||
H26.3z"/>
|
||||
<path class="st2" d="M37.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
|
||||
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
|
||||
<path class="st2" d="M68.7,101.8c8.5,0,13.9,5.6,13.9,14.2c0,8.5-5.5,14.1-13.9,14.1c-8.4,0-13.9-5.6-13.9-14.1
|
||||
C54.9,107.4,60.3,101.8,68.7,101.8z M68.7,124.5c5,0,6.5-4.3,6.5-8.6c0-4.3-1.5-8.6-6.5-8.6c-5,0-6.5,4.3-6.5,8.6
|
||||
C62.2,120.2,63.8,124.5,68.7,124.5z"/>
|
||||
<path class="st2" d="M91.2,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2c-4.3-0.9-8.5-2.4-8.5-7.2
|
||||
c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5c0,2.6,4.2,3,8.4,4
|
||||
c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H91.2z"/>
|
||||
<path class="st2" d="M118.1,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2
|
||||
c-4.3-0.9-8.5-2.4-8.5-7.2c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5
|
||||
c0,2.6,4.2,3,8.4,4c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H118.1z"/>
|
||||
<path class="st2" d="M138.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
|
||||
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
|
||||
<path class="st2" d="M163.7,117.7c0.2,4.7,2.5,6.8,6.6,6.8c3,0,5.3-1.8,5.8-3.5h6.5c-2.1,6.3-6.5,9-12.6,9
|
||||
c-8.5,0-13.7-5.8-13.7-14.1c0-8,5.6-14.2,13.7-14.2c9.1,0,13.6,7.7,13,15.9H163.7z M175.7,113.1c-0.7-3.7-2.3-5.7-5.9-5.7
|
||||
c-4.7,0-6,3.6-6.1,5.7H175.7z"/>
|
||||
<path class="st2" d="M187.2,107.5h-4.4v-4.9h4.4v-2.1c0-4.7,3-8.2,9-8.2c1.3,0,2.6,0.2,3.9,0.2V98c-0.9-0.1-1.8-0.2-2.7-0.2
|
||||
c-2,0-2.8,0.8-2.8,3.1v1.6h5.1v4.9h-5.1v21.9h-7.4V107.5z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 3.0 KiB |