mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 04:36:31 +08:00
181 lines
6.8 KiB
Python
181 lines
6.8 KiB
Python
import logging
|
|
import os
|
|
from collections.abc import Mapping
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from configs import dify_config
|
|
from core.helper.trace_id_helper import generate_traceparent_header
|
|
from services.errors.enterprise import (
|
|
EnterpriseAPIBadRequestError,
|
|
EnterpriseAPIError,
|
|
EnterpriseAPIForbiddenError,
|
|
EnterpriseAPINotFoundError,
|
|
EnterpriseAPIUnauthorizedError,
|
|
)
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)

# Header names understood by dify-enterprise's /inner/api/rbac/* endpoints.
# They must match pkg/enterprise/service/rbac_inner_handlers.go -- change both sides together.
INNER_TENANT_ID_HEADER: str = "X-Inner-Tenant-Id"  # tenant targeted by an inner RBAC call
INNER_ACCOUNT_ID_HEADER: str = "X-Inner-Account-Id"  # acting account (optional for callers)
|
|
|
|
|
|
class BaseRequest:
    """Base class for synchronous JSON calls to an enterprise-side HTTP API.

    Subclasses point it at a concrete service by overriding the class
    attributes below. Every call goes through :meth:`send_request`, which
    attaches the shared secret header and a best-effort ``traceparent`` header,
    and translates non-2xx responses into ``EnterpriseAPI*Error`` exceptions.
    """

    # Scheme -> proxy URL. Entries with empty values are skipped, so the
    # all-empty default results in no explicit proxy transports being mounted.
    proxies: Mapping[str, str] | None = {
        "http": "",
        "https": "",
    }
    # Prefix prepended verbatim to every endpoint passed to send_request.
    base_url = ""
    # Shared secret, sent on every request under the `secret_key_header` name.
    secret_key = ""
    secret_key_header = ""

    @classmethod
    def _build_mounts(cls) -> dict[str, httpx.BaseTransport] | None:
        """Build httpx transport mounts from ``cls.proxies``.

        Scheme keys such as ``"http"`` are normalised to the ``"http://"``
        URL-pattern form used for httpx mounts. Returns ``None`` when no
        non-empty proxy is configured, which is also httpx.Client's default
        ``mounts`` value.
        """
        if not cls.proxies:
            return None

        mounts: dict[str, httpx.BaseTransport] = {}
        for scheme, value in cls.proxies.items():
            if not value:
                continue
            key = f"{scheme}://" if not scheme.endswith("://") else scheme
            mounts[key] = httpx.HTTPTransport(proxy=value)
        return mounts or None

    @classmethod
    def send_request(
        cls,
        method: str,
        endpoint: str,
        json: Any | None = None,
        params: Mapping[str, Any] | None = None,
        *,
        timeout: float | httpx.Timeout | None = None,
        raise_for_status: bool = False,
        extra_headers: Mapping[str, str] | None = None,
    ) -> Any:
        """Send a request to ``cls.base_url + endpoint`` and return the decoded JSON body.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            endpoint: Path appended verbatim to ``cls.base_url``.
            json: Optional JSON-serialisable request body.
            params: Optional query-string parameters.
            timeout: Per-request timeout. ``None`` keeps httpx's default timeout,
                because the kwarg is omitted rather than forwarded as ``None``.
            raise_for_status: Accepted for backward compatibility only. It has no
                effect, because every non-2xx response already raises a domain
                error via ``_handle_error_response``.
            extra_headers: Additional request headers. Entries whose value is
                ``None`` or ``""`` are dropped, so optional headers can be passed
                without branching at the call site.

        Returns:
            The parsed JSON body, or ``None`` when a successful response has an
            empty body (e.g. ``204 No Content``).

        Raises:
            EnterpriseAPIError: or one of its status-specific variants (400, 401,
                403, 404), for any non-2xx response.
            httpx exceptions from the request itself (connection errors,
                timeouts) propagate unchanged.
        """
        headers: dict[str, str] = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
        if extra_headers:
            # Explicitly ignore empty values so callers can pass optional
            # headers (e.g. `X-Inner-Account-Id`) without having to branch.
            for key, value in extra_headers.items():
                if value is None or value == "":
                    continue
                headers[key] = value

        url = f"{cls.base_url}{endpoint}"
        mounts = cls._build_mounts()

        try:
            # Ensure a traceparent even when OTEL is disabled. Tracing is
            # best-effort and must never prevent the actual request.
            traceparent = generate_traceparent_header()
            if traceparent:
                headers["traceparent"] = traceparent
        except Exception:
            logger.debug("Failed to generate traceparent header", exc_info=True)

        # IMPORTANT: in httpx, passing timeout=None disables timeouts entirely and
        # overrides the library default. To preserve httpx's default timeout for
        # existing call sites, only pass the kwarg when it is set.
        request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers}
        if timeout is not None:
            request_kwargs["timeout"] = timeout

        with httpx.Client(mounts=mounts) as client:
            response = client.request(method, url, **request_kwargs)

            # Validate HTTP status and raise domain-specific errors.
            if not response.is_success:
                cls._handle_error_response(response)

            # A successful response without a body (e.g. 204 No Content) has
            # nothing to decode; response.json() would raise JSONDecodeError.
            if not response.content:
                return None
            return response.json()

    @classmethod
    def _handle_error_response(cls, response: httpx.Response) -> None:
        """Raise the domain error matching a non-2xx ``response``.

        The message is taken from the first truthy ``message``, ``error`` or
        ``detail`` field of a JSON object body. If the body is not JSON, or has
        none of those fields, it falls back to
        ``"Enterprise API request failed: <status> <reason>"``.

        This method always raises and never returns.
        """
        error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"

        try:
            error_data = response.json()
        except Exception:
            # Non-JSON or undecodable body: keep the default message.
            logger.debug(
                "Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
            )
        else:
            if isinstance(error_data, dict):
                # Common error response formats:
                #   {"error": "...", "message": "..."}
                #   {"message": "..."}
                #   {"detail": "..."}
                detail = error_data.get("message") or error_data.get("error") or error_data.get("detail")
                if detail:
                    # These fields may hold structured values (e.g. a list of
                    # validation errors). Normalise to text so the exception
                    # message is always a str.
                    error_message = detail if isinstance(detail, str) else str(detail)

        # Raise a specific error based on the status code.
        status_code = response.status_code
        if status_code == 400:
            raise EnterpriseAPIBadRequestError(error_message)
        if status_code == 401:
            raise EnterpriseAPIUnauthorizedError(error_message)
        if status_code == 403:
            raise EnterpriseAPIForbiddenError(error_message)
        if status_code == 404:
            raise EnterpriseAPINotFoundError(error_message)
        raise EnterpriseAPIError(error_message, status_code=status_code)
|
|
|
|
|
|
class EnterpriseRequest(BaseRequest):
    """Client for the dify-enterprise API.

    ``base_url`` and ``secret_key`` are read from the ``ENTERPRISE_API_URL``
    and ``ENTERPRISE_API_SECRET_KEY`` environment variables once, when the
    class body executes at import time. An unset variable falls back to its
    own name as a placeholder value.
    """

    base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
    secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
    secret_key_header = "Enterprise-Api-Secret-Key"

    @classmethod
    def send_inner_rbac_request(
        cls,
        method: str,
        endpoint: str,
        *,
        tenant_id: str,
        account_id: str | None = None,
        json: Any | None = None,
        params: Mapping[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
    ) -> Any:
        """Call an ``/inner/api/rbac/*`` endpoint on dify-enterprise.

        On top of the standard ``Enterprise-Api-Secret-Key`` header, these
        endpoints take up to two more headers:

        * ``INNER_TENANT_ID_HEADER``: required, the tenant the call targets.
        * ``INNER_ACCOUNT_ID_HEADER``: optional, the account acting within that
          workspace. It is omitted when ``account_id`` is falsy.

        This helper centralises both the argument checks and the header wiring,
        so callers only have to supply the business payload.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            endpoint: Path appended to ``base_url``.
            tenant_id: Target tenant id. Must be non-empty.
            account_id: Optional acting account id.
            json: Optional JSON-serialisable request body.
            params: Optional query-string parameters.
            timeout: Optional per-request timeout, forwarded to ``send_request``.

        Returns:
            Whatever ``send_request`` returns for this call.

        Raises:
            EnterpriseAPIError: if ``dify_config.ENTERPRISE_ENABLED`` is falsy,
                or (from ``send_request``) for any non-2xx response.
            ValueError: if ``tenant_id`` is empty.
        """
        if not dify_config.ENTERPRISE_ENABLED:
            raise EnterpriseAPIError("Enterprise edition is not enabled")
        if not tenant_id:
            raise ValueError("tenant_id must be provided for inner RBAC requests")

        inner_headers: dict[str, str] = {INNER_TENANT_ID_HEADER: tenant_id}
        if account_id:
            inner_headers[INNER_ACCOUNT_ID_HEADER] = account_id
        return cls.send_request(
            method,
            endpoint,
            json=json,
            params=params,
            timeout=timeout,
            extra_headers=inner_headers,
        )
|
|
|
|
|
|
class EnterprisePluginManagerRequest(BaseRequest):
    """Client for the enterprise plugin-manager inner API.

    Connection settings come from the ``ENTERPRISE_PLUGIN_MANAGER_API_URL`` and
    ``ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY`` environment variables, read
    once when the class body executes. An unset variable falls back to its own
    name as a placeholder value.
    """

    secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key"
    base_url = os.environ.get(
        "ENTERPRISE_PLUGIN_MANAGER_API_URL",
        "ENTERPRISE_PLUGIN_MANAGER_API_URL",
    )
    secret_key = os.environ.get(
        "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY",
        "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY",
    )
|