mirror of
https://github.com/langgenius/dify.git
synced 2026-05-11 23:18:39 +08:00
feat(inner-api): resolve runtime credentials
This commit is contained in:
parent
aa1430aa16
commit
253888f758
@ -16,6 +16,7 @@ api = ExternalApi(
|
||||
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
||||
|
||||
from . import mail as _mail
|
||||
from . import runtime_credentials as _runtime_credentials
|
||||
from .app import dsl as _app_dsl
|
||||
from .plugin import plugin as _plugin
|
||||
from .workspace import workspace as _workspace
|
||||
@ -26,6 +27,7 @@ __all__ = [
|
||||
"_app_dsl",
|
||||
"_mail",
|
||||
"_plugin",
|
||||
"_runtime_credentials",
|
||||
"_workspace",
|
||||
"api",
|
||||
"bp",
|
||||
|
||||
129
api/controllers/inner_api/runtime_credentials.py
Normal file
129
api/controllers/inner_api/runtime_credentials.py
Normal file
@ -0,0 +1,129 @@
|
||||
"""Inner API endpoints for runtime credential resolution.
|
||||
|
||||
Called by Enterprise while resolving AppRunner runtime artifacts. The endpoint
|
||||
returns decrypted model credentials for in-memory runtime use only.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from core.helper import encrypter
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from extensions.ext_database import db
|
||||
from models.provider import ProviderCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InnerRuntimeModelCredentialResolveItem(BaseModel):
|
||||
credential_id: str = Field(description="Provider credential id")
|
||||
provider: str = Field(description="Runtime provider identifier, for example langgenius/openai/openai")
|
||||
vendor: str | None = Field(default=None, description="Model vendor, for example openai")
|
||||
plugin_unique_identifier: str | None = Field(default=None, description="Runtime plugin identifier")
|
||||
|
||||
|
||||
class InnerRuntimeModelCredentialsResolvePayload(BaseModel):
|
||||
tenant_id: str = Field(description="Workspace id")
|
||||
credentials: list[InnerRuntimeModelCredentialResolveItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_model(inner_api_ns, InnerRuntimeModelCredentialsResolvePayload)
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/runtime/model-credentials:resolve")
|
||||
class EnterpriseRuntimeModelCredentialsResolve(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc(
|
||||
"enterprise_runtime_model_credentials_resolve",
|
||||
responses={
|
||||
200: "Credentials resolved",
|
||||
400: "Invalid request or credential config",
|
||||
404: "Provider or credential not found",
|
||||
},
|
||||
)
|
||||
@inner_api_ns.expect(inner_api_ns.models[InnerRuntimeModelCredentialsResolvePayload.__name__])
|
||||
def post(self):
|
||||
args = InnerRuntimeModelCredentialsResolvePayload.model_validate(inner_api_ns.payload or {})
|
||||
if not args.credentials:
|
||||
return {"model_credentials": []}, 200
|
||||
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=args.tenant_id)
|
||||
provider_configurations = provider_manager.get_configurations(args.tenant_id)
|
||||
|
||||
resolved: list[dict[str, Any]] = []
|
||||
for item in args.credentials:
|
||||
provider_configuration = provider_configurations.get(item.provider)
|
||||
if provider_configuration is None:
|
||||
return {"message": f"provider '{item.provider}' not found"}, 404
|
||||
|
||||
provider_schema = provider_configuration.provider.provider_credential_schema
|
||||
secret_variables = provider_configuration.extract_secret_variables(
|
||||
provider_schema.credential_form_schemas if provider_schema else []
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == item.credential_id,
|
||||
ProviderCredential.tenant_id == args.tenant_id,
|
||||
ProviderCredential.provider_name.in_(provider_configuration._get_provider_names()),
|
||||
)
|
||||
credential = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if credential is None or not credential.encrypted_config:
|
||||
return {"message": f"credential '{item.credential_id}' not found"}, 404
|
||||
|
||||
try:
|
||||
values = json.loads(credential.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
return {"message": f"credential '{item.credential_id}' has invalid config"}, 400
|
||||
if not isinstance(values, dict):
|
||||
return {"message": f"credential '{item.credential_id}' has invalid config"}, 400
|
||||
|
||||
for key in secret_variables:
|
||||
value = values.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
try:
|
||||
values[key] = encrypter.decrypt_token(tenant_id=args.tenant_id, token=value)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"failed to resolve runtime model credential",
|
||||
extra={
|
||||
"credential_id": item.credential_id,
|
||||
"provider": item.provider,
|
||||
"tenant_id": args.tenant_id,
|
||||
"error": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
return {"message": f"credential '{item.credential_id}' decrypt failed"}, 400
|
||||
|
||||
resolved.append(
|
||||
{
|
||||
"credential_id": item.credential_id,
|
||||
"provider": item.provider,
|
||||
"vendor": item.vendor or _vendor_from_provider(item.provider),
|
||||
"plugin_unique_identifier": item.plugin_unique_identifier,
|
||||
"values": values,
|
||||
}
|
||||
)
|
||||
|
||||
return {"model_credentials": resolved}, 200
|
||||
|
||||
|
||||
def _vendor_from_provider(provider: str) -> str:
|
||||
provider = provider.strip("/")
|
||||
if not provider:
|
||||
return ""
|
||||
return provider.rsplit("/", 1)[-1]
|
||||
@ -0,0 +1,105 @@
|
||||
"""Unit tests for runtime credential inner API."""
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from controllers.inner_api.runtime_credentials import (
|
||||
EnterpriseRuntimeModelCredentialsResolve,
|
||||
InnerRuntimeModelCredentialsResolvePayload,
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_model_credentials_payload_accepts_items():
|
||||
payload = InnerRuntimeModelCredentialsResolvePayload.model_validate(
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"credentials": [
|
||||
{
|
||||
"credential_id": "credential-1",
|
||||
"provider": "langgenius/openai/openai",
|
||||
"vendor": "openai",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
assert payload.tenant_id == "tenant-1"
|
||||
assert payload.credentials[0].provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
@patch("controllers.inner_api.runtime_credentials.encrypter.decrypt_token")
|
||||
@patch("controllers.inner_api.runtime_credentials.db")
|
||||
@patch("controllers.inner_api.runtime_credentials.Session")
|
||||
@patch("controllers.inner_api.runtime_credentials.create_plugin_provider_manager")
|
||||
def test_runtime_model_credentials_resolve_returns_decrypted_values(
|
||||
mock_provider_manager_factory,
|
||||
mock_session_cls,
|
||||
mock_db,
|
||||
mock_decrypt_token,
|
||||
app: Flask,
|
||||
):
|
||||
provider_configuration = MagicMock()
|
||||
provider_configuration.provider.provider_credential_schema.credential_form_schemas = []
|
||||
provider_configuration.extract_secret_variables.return_value = ["openai_api_key"]
|
||||
provider_configuration._get_provider_names.return_value = ["langgenius/openai/openai", "openai"]
|
||||
|
||||
provider_configurations = MagicMock()
|
||||
provider_configurations.get.return_value = provider_configuration
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value = provider_configurations
|
||||
mock_provider_manager_factory.return_value = provider_manager
|
||||
|
||||
credential = MagicMock()
|
||||
credential.encrypted_config = '{"openai_api_key":"encrypted","api_base":"https://api.openai.com/v1"}'
|
||||
session = MagicMock()
|
||||
session.__enter__.return_value = session
|
||||
session.__exit__.return_value = False
|
||||
session.execute.return_value.scalar_one_or_none.return_value = credential
|
||||
mock_session_cls.return_value = session
|
||||
mock_db.engine = MagicMock()
|
||||
mock_decrypt_token.return_value = "sk-test"
|
||||
|
||||
handler = EnterpriseRuntimeModelCredentialsResolve()
|
||||
unwrapped = inspect.unwrap(handler.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.runtime_credentials.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"tenant_id": "tenant-1",
|
||||
"credentials": [
|
||||
{
|
||||
"credential_id": "credential-1",
|
||||
"provider": "langgenius/openai/openai",
|
||||
"vendor": "openai",
|
||||
}
|
||||
],
|
||||
}
|
||||
body, status_code = unwrapped(handler)
|
||||
|
||||
assert status_code == 200
|
||||
assert body["model_credentials"][0]["values"]["openai_api_key"] == "sk-test"
|
||||
assert body["model_credentials"][0]["values"]["api_base"] == "https://api.openai.com/v1"
|
||||
mock_decrypt_token.assert_called_once_with(tenant_id="tenant-1", token="encrypted")
|
||||
|
||||
|
||||
@patch("controllers.inner_api.runtime_credentials.create_plugin_provider_manager")
|
||||
def test_runtime_model_credentials_resolve_rejects_unknown_provider(mock_provider_manager_factory, app: Flask):
|
||||
provider_configurations = MagicMock()
|
||||
provider_configurations.get.return_value = None
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value = provider_configurations
|
||||
mock_provider_manager_factory.return_value = provider_manager
|
||||
|
||||
handler = EnterpriseRuntimeModelCredentialsResolve()
|
||||
unwrapped = inspect.unwrap(handler.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.runtime_credentials.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"tenant_id": "tenant-1",
|
||||
"credentials": [{"credential_id": "credential-1", "provider": "missing"}],
|
||||
}
|
||||
body, status_code = unwrapped(handler)
|
||||
|
||||
assert status_code == 404
|
||||
assert "provider" in body["message"]
|
||||
Loading…
Reference in New Issue
Block a user