mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
feat(dify-cli): session level tool white list
This commit is contained in:
parent
a9e1394011
commit
89eb7b17db
@ -1,3 +1,4 @@
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.cli_api import cli_api_ns
|
||||
@ -15,6 +16,8 @@ from core.plugin.entities.request import (
|
||||
RequestInvokeTool,
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.session.cli_api import CliContext
|
||||
from core.skill.entities import ToolInvocationRequest
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from libs.helper import length_prefixed_response
|
||||
from models import Account, Tenant
|
||||
@ -23,9 +26,9 @@ from models.model import EndUser
|
||||
|
||||
@cli_api_ns.route("/invoke/llm")
|
||||
class CliInvokeLLMApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
|
||||
def generator():
|
||||
@ -37,17 +40,34 @@ class CliInvokeLLMApi(Resource):
|
||||
|
||||
@cli_api_ns.route("/invoke/tool")
|
||||
class CliInvokeToolApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestInvokeTool,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
tool_type = ToolProviderType.value_of(payload.tool_type)
|
||||
|
||||
request = ToolInvocationRequest(
|
||||
tool_type=tool_type,
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
|
||||
abort(403)
|
||||
|
||||
def generator():
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
PluginToolBackwardsInvocation.invoke_tool(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
tool_type=ToolProviderType.value_of(payload.tool_type),
|
||||
tool_type=tool_type,
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
tool_parameters=payload.tool_parameters,
|
||||
@ -60,9 +80,9 @@ class CliInvokeToolApi(Resource):
|
||||
|
||||
@cli_api_ns.route("/invoke/app")
|
||||
class CliInvokeAppApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
|
||||
response = PluginAppBackwardsInvocation.invoke_app(
|
||||
@ -81,9 +101,9 @@ class CliInvokeAppApi(Resource):
|
||||
|
||||
@cli_api_ns.route("/upload/file/request")
|
||||
class CliUploadFileRequestApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
|
||||
# generate signed url
|
||||
@ -98,9 +118,9 @@ class CliUploadFileRequestApi(Resource):
|
||||
|
||||
@cli_api_ns.route("/fetch/tools/list")
|
||||
class CliFetchToolsListApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@cli_api_only
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant):
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@ -2,12 +2,12 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask import current_app, g, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.session.cli_api import CliApiSession, CliApiSessionManager
|
||||
from core.session.cli_api import CliApiSession, CliContext
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
@ -75,22 +75,13 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
def get_cli_user_tenant(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
session_id = request.headers.get("X-Cli-Api-Session-Id")
|
||||
session: CliApiSession | None = getattr(g, "cli_api_session", None)
|
||||
if session is None:
|
||||
raise ValueError("session not found")
|
||||
|
||||
if session_id:
|
||||
session: CliApiSession | None = CliApiSessionManager().get(session_id)
|
||||
if not session:
|
||||
raise ValueError("session not found")
|
||||
user_id = session.user_id
|
||||
tenant_id = session.tenant_id
|
||||
|
||||
else:
|
||||
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
|
||||
user_id = payload.user_id
|
||||
tenant_id = payload.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required")
|
||||
user_id = session.user_id
|
||||
tenant_id = session.tenant_id
|
||||
cli_context = CliContext.model_validate(session.context)
|
||||
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
@ -110,11 +101,10 @@ def get_cli_user_tenant(view_func: Callable[P, R]):
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
kwargs["cli_context"] = cli_context
|
||||
|
||||
user = get_user(tenant_id, user_id)
|
||||
kwargs["user_model"] = user
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
current_app.login_manager._update_request_context_with_user(kwargs["user_model"]) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
@ -5,7 +5,7 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask import abort, g, request
|
||||
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
|
||||
@ -49,6 +49,8 @@ def cli_api_only(view: Callable[P, R]):
|
||||
if not _verify_signature(session.secret, timestamp, body, signature):
|
||||
abort(401)
|
||||
|
||||
g.cli_api_session = session
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
@ -1 +0,0 @@
|
||||
# refactor the package import paths
|
||||
@ -6,7 +6,8 @@ from io import BytesIO
|
||||
from types import TracebackType
|
||||
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.session.cli_api import CliApiSession, CliApiSessionManager
|
||||
from core.session.cli_api import CliApiSession, CliApiSessionManager, CliContext
|
||||
from core.skill.entities import ToolAccessPolicy
|
||||
from core.skill.entities.tool_dependencies import ToolDependencies
|
||||
from core.virtual_environment.__base.helpers import pipeline
|
||||
|
||||
@ -37,6 +38,7 @@ class SandboxBashSession:
|
||||
self._cli_api_session = CliApiSessionManager().create(
|
||||
tenant_id=self._tenant_id,
|
||||
user_id=self._user_id,
|
||||
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(self._tools)),
|
||||
)
|
||||
if self._tools is not None and not self._tools.is_empty():
|
||||
tools_path = self._setup_node_tools_directory(self._node_id, self._tools, self._cli_api_session)
|
||||
@ -55,7 +57,7 @@ class SandboxBashSession:
|
||||
node_id: str,
|
||||
tools: ToolDependencies,
|
||||
cli_api_session: CliApiSession,
|
||||
) -> str | None:
|
||||
) -> str:
|
||||
node_tools_path = f"{DifyCli.TOOLS_ROOT}/{node_id}"
|
||||
|
||||
vm = self._sandbox.vm
|
||||
|
||||
@ -6,7 +6,8 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
from core.session.cli_api import CliApiSessionManager, CliContext
|
||||
from core.skill.entities import ToolAccessPolicy
|
||||
from core.skill.skill_manager import SkillManager
|
||||
from core.virtual_environment.__base.helpers import pipeline
|
||||
|
||||
@ -63,7 +64,11 @@ class DifyCliInitializer(AsyncSandboxInitializer):
|
||||
logger.info("No tools found in bundle for assets_id=%s", self._assets_id)
|
||||
return
|
||||
|
||||
self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id)
|
||||
self._cli_api_session = CliApiSessionManager().create(
|
||||
tenant_id=self._tenant_id,
|
||||
user_id=self._user_id,
|
||||
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
|
||||
)
|
||||
|
||||
pipeline(vm).add(
|
||||
["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir"
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import secrets
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.skill.entities import ToolAccessPolicy
|
||||
|
||||
from .session import BaseSession, SessionManager
|
||||
|
||||
@ -10,11 +11,15 @@ class CliApiSession(BaseSession):
|
||||
secret: str = Field(default_factory=lambda: secrets.token_urlsafe(32))
|
||||
|
||||
|
||||
class CliContext(BaseModel):
|
||||
tool_access: ToolAccessPolicy | None = Field(default=None, description="Tool access policy")
|
||||
|
||||
|
||||
class CliApiSessionManager(SessionManager[CliApiSession]):
|
||||
def __init__(self, ttl: int | None = None):
|
||||
super().__init__(key_prefix="cli_api_session", session_class=CliApiSession, ttl=ttl)
|
||||
|
||||
def create(self, tenant_id: str, user_id: str, context: dict[str, Any] | None = None) -> CliApiSession:
|
||||
session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context or {})
|
||||
def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession:
|
||||
session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json"))
|
||||
self.save(session)
|
||||
return session
|
||||
|
||||
@ -9,6 +9,7 @@ from .skill_metadata import (
|
||||
ToolFieldConfig,
|
||||
ToolReference,
|
||||
)
|
||||
from .tool_access_policy import ToolAccessPolicy, ToolInvocationRequest, ToolKey
|
||||
from .tool_dependencies import ToolDependencies, ToolDependency
|
||||
|
||||
__all__ = [
|
||||
@ -19,9 +20,12 @@ __all__ = [
|
||||
"SkillDocument",
|
||||
"SkillMetadata",
|
||||
"SourceInfo",
|
||||
"ToolAccessPolicy",
|
||||
"ToolConfiguration",
|
||||
"ToolDependencies",
|
||||
"ToolDependency",
|
||||
"ToolFieldConfig",
|
||||
"ToolInvocationRequest",
|
||||
"ToolKey",
|
||||
"ToolReference",
|
||||
]
|
||||
|
||||
87
api/core/skill/entities/tool_access_policy.py
Normal file
87
api/core/skill/entities/tool_access_policy.py
Normal file
@ -0,0 +1,87 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.skill.entities.tool_dependencies import ToolDependencies
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class ToolKey(BaseModel):
|
||||
"""Immutable identifier for a tool (type + provider + name)."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
tool_type: ToolProviderType
|
||||
provider: str
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolInvocationRequest(BaseModel):
|
||||
"""A request to invoke a specific tool with optional credential."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
tool_type: ToolProviderType
|
||||
provider: str
|
||||
tool_name: str
|
||||
credential_id: str | None = None
|
||||
|
||||
@property
|
||||
def key(self) -> ToolKey:
|
||||
return ToolKey(tool_type=self.tool_type, provider=self.provider, tool_name=self.tool_name)
|
||||
|
||||
|
||||
class ToolAccessPolicy(BaseModel):
|
||||
"""
|
||||
Determines whether a tool invocation is allowed based on ToolDependencies.
|
||||
|
||||
Rules:
|
||||
1. Tool must be declared in dependencies or references.
|
||||
2. If references exist for the tool, credential_id must match one of them.
|
||||
3. If no references exist for the tool, credential_id must be None.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
allowed_tools: frozenset[ToolKey] = Field(default_factory=frozenset)
|
||||
credential_ids_by_tool: dict[ToolKey, frozenset[str | None]] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dependencies(cls, deps: ToolDependencies | None) -> "ToolAccessPolicy":
|
||||
if deps is None or deps.is_empty():
|
||||
return cls()
|
||||
|
||||
def to_key(t: ToolProviderType, p: str, n: str) -> ToolKey:
|
||||
return ToolKey(tool_type=t, provider=p, tool_name=n)
|
||||
|
||||
tools: set[ToolKey] = set()
|
||||
tools.update(to_key(dep.type, dep.provider, dep.tool_name) for dep in deps.dependencies)
|
||||
tools.update(to_key(ref.type, ref.provider, ref.tool_name) for ref in deps.references)
|
||||
|
||||
creds: dict[ToolKey, set[str | None]] = {}
|
||||
for ref in deps.references:
|
||||
key = to_key(ref.type, ref.provider, ref.tool_name)
|
||||
creds.setdefault(key, set()).add(ref.credential_id)
|
||||
|
||||
return cls(
|
||||
allowed_tools=frozenset(tools),
|
||||
credential_ids_by_tool={k: frozenset(v) for k, v in creds.items()},
|
||||
)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.allowed_tools) == 0
|
||||
|
||||
def is_allowed(self, request: ToolInvocationRequest) -> bool:
|
||||
"""Check if the tool invocation request is allowed."""
|
||||
|
||||
# If the policy is empty, allow any invocation.
|
||||
if self.is_empty():
|
||||
return True
|
||||
|
||||
if request.key not in self.allowed_tools:
|
||||
return False
|
||||
|
||||
allowed_credentials = self.credential_ids_by_tool.get(request.key)
|
||||
if not allowed_credentials:
|
||||
# No references for this tool: only allow invocation without credential.
|
||||
return request.credential_id is None
|
||||
|
||||
return request.credential_id in allowed_credentials
|
||||
Loading…
Reference in New Issue
Block a user