mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
parent
ce68f2cdc6
commit
47b9d48f70
@ -205,16 +205,160 @@ class ToolManager:
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
|
||||
builtin_tool = provider_controller.get_tool(tool_name)
|
||||
if not builtin_tool:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
builtin_tool = provider_controller.get_tool(tool_name)
|
||||
if not builtin_tool:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
if is_valid_uuid(credential_id):
|
||||
try:
|
||||
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
||||
except Exception as e:
|
||||
builtin_provider = None
|
||||
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||
|
||||
if builtin_provider is None:
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=builtin_provider.id,
|
||||
provider=provider_id,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
check_existence=False,
|
||||
)
|
||||
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(
|
||||
encrypter.encrypt(refreshed_credentials.credentials)
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
cache.delete()
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=builtin_provider.credential_type,
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
|
||||
case ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=api_provider,
|
||||
)
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(encrypter.decrypt(credentials)),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
case ToolProviderType.WORKFLOW:
|
||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
@ -223,177 +367,28 @@ class ToolManager:
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
# get specific credentials
|
||||
if is_valid_uuid(credential_id):
|
||||
try:
|
||||
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
||||
except Exception as e:
|
||||
builtin_provider = None
|
||||
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
||||
# if the provider has been deleted, raise an error
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||
|
||||
# fallback to the default provider
|
||||
if builtin_provider is None:
|
||||
# use the default provider
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
# check if the credential is allowed to be used
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=builtin_provider.id,
|
||||
provider=provider_id,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
check_existence=False,
|
||||
)
|
||||
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
# decrypt the credentials
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||
|
||||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# refresh the credentials
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(
|
||||
encrypter.encrypt(refreshed_credentials.credentials)
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
cache.delete()
|
||||
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=builtin_provider.credential_type,
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=api_provider,
|
||||
)
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(encrypter.decrypt(credentials)),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
elif provider_type == ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
runtime = getattr(plugin_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return plugin_tool
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
runtime = getattr(mcp_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return mcp_tool
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
case ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
case ToolProviderType.PLUGIN:
|
||||
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
runtime = getattr(plugin_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return plugin_tool
|
||||
case ToolProviderType.MCP:
|
||||
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
runtime = getattr(mcp_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return mcp_tool
|
||||
case ToolProviderType.DATASET_RETRIEVAL:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
case _:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(
|
||||
@ -1027,31 +1022,31 @@ class ToolManager:
|
||||
:param provider_id: the id of the provider
|
||||
:return:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||
case ToolProviderType.API:
|
||||
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.WORKFLOW:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||
elif provider_type == ToolProviderType.API:
|
||||
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
case ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
case _:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_parameters_type(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user