mirror of https://github.com/langgenius/dify.git
feat: add custom OAuth client setup and enhance datasource provider model with avatar_url
This commit is contained in:
parent
7364d051d2
commit
e97f03c130
|
|
@ -81,12 +81,13 @@ class DatasourceOAuthCallback(Resource):
|
|||
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
plugin_oauth_config = (
|
||||
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
oauth_client_params = datasource_provider_service.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not plugin_oauth_config:
|
||||
if not oauth_client_params:
|
||||
raise NotFound()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
oauth_handler = OAuthHandler()
|
||||
|
|
@ -96,10 +97,9 @@ class DatasourceOAuthCallback(Resource):
|
|||
plugin_id=plugin_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=plugin_oauth_config.system_credentials,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
|
|
@ -205,8 +205,28 @@ class DatasourceAuthListApi(Resource):
|
|||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enabled", False),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
||||
|
|
@ -229,3 +249,8 @@ api.add_resource(
|
|||
DatasourceAuthListApi,
|
||||
"/auth/plugin/datasource/list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthOauthCustomClient,
|
||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -361,4 +361,4 @@ class PluginDatasourceManager(BasePluginClient):
|
|||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -124,11 +124,15 @@ class ProviderConfigEncrypter:
|
|||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
def create_provider_encrypter(
|
||||
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
def create_tool_provider_encrypter(
|
||||
tenant_id: str, controller: ToolProviderController
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=controller.provider_type.value,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Any, Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
"""add_pipeline_info_14
|
||||
|
||||
Revision ID: d3c68680d3ba
|
||||
Revises: fcb46171d891
|
||||
Create Date: 2025-07-21 12:20:29.582951
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd3c68680d3ba'
|
||||
down_revision = 'fcb46171d891'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('datasource_oauth_tenant_params',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('plugin_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('client_params', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('datasource_oauth_tenant_params')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -35,7 +35,25 @@ class DatasourceProvider(Base):
|
|||
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
|
||||
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True)
|
||||
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
|
||||
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||
|
||||
|
||||
class DatasourceOauthTenantParamConfig(Base):
|
||||
__tablename__ = "datasource_oauth_tenant_params"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
client_params: Mapped[dict] = db.Column(JSONB, nullable=False, default={})
|
||||
enabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,22 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.helper import encrypter
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.model_runtime.entities.provider_entities import FormType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.entities.plugin import DatasourceProviderID
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.oauth import DatasourceProvider
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -26,6 +29,165 @@ class DatasourceProviderService:
|
|||
def __init__(self) -> None:
|
||||
self.provider_manager = PluginDatasourceManager()
|
||||
|
||||
def setup_oauth_custom_client_params(
|
||||
self,
|
||||
tenant_id: str,
|
||||
datasource_provider_id: DatasourceProviderID,
|
||||
client_params: dict | None,
|
||||
enabled: bool | None,
|
||||
):
|
||||
"""
|
||||
setup oauth custom client params
|
||||
"""
|
||||
if client_params is None and enabled is None:
|
||||
return
|
||||
provider_controller = PluginDatasourceManager()
|
||||
datasource_provider = provider_controller.fetch_datasource_provider(
|
||||
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
||||
)
|
||||
if not datasource_provider.declaration.oauth_schema:
|
||||
raise ValueError("Datasource provider oauth schema not found")
|
||||
with Session(db.engine) as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tenant_oauth_client_params:
|
||||
tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
client_params={},
|
||||
enabled=False,
|
||||
)
|
||||
session.add(tenant_oauth_client_params)
|
||||
|
||||
if client_params is not None:
|
||||
client_schema = datasource_provider.declaration.oauth_schema.client_schema
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in client_schema],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
original_params = (
|
||||
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
|
||||
)
|
||||
new_params: dict = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
|
||||
|
||||
if enabled is not None:
|
||||
tenant_oauth_client_params.enabled = enabled
|
||||
session.commit()
|
||||
|
||||
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
|
||||
"""
|
||||
check if system oauth params exist
|
||||
"""
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
return (
|
||||
session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
|
||||
"""
|
||||
check if tenant oauth params is enabled
|
||||
"""
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
return (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.count()
|
||||
> 0
|
||||
)
|
||||
|
||||
def get_tenant_oauth_client(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
get tenant oauth client
|
||||
"""
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if tenant_oauth_client_params:
|
||||
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
||||
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
||||
return None
|
||||
|
||||
def get_oauth_encrypter(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
"""
|
||||
get oauth encrypter
|
||||
"""
|
||||
datasource_provider = self.provider_manager.fetch_datasource_provider(
|
||||
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
||||
)
|
||||
if not datasource_provider.declaration.oauth_schema:
|
||||
raise ValueError("Datasource provider oauth schema not found")
|
||||
|
||||
client_schema = datasource_provider.declaration.oauth_schema.client_schema
|
||||
return create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in client_schema],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
|
||||
"""
|
||||
get oauth client
|
||||
"""
|
||||
provider = datasource_provider_id.provider_name
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
# get tenant oauth client params
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if tenant_oauth_client_params:
|
||||
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
||||
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
||||
|
||||
# fallback to system oauth client params
|
||||
oauth_client_params = (
|
||||
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||
)
|
||||
if oauth_client_params:
|
||||
return oauth_client_params.system_credentials
|
||||
|
||||
raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
|
||||
|
||||
@staticmethod
|
||||
def generate_next_datasource_provider_name(
|
||||
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
|
||||
|
|
@ -69,24 +231,29 @@ class DatasourceProviderService:
|
|||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
if session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
name=db_provider_name,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=credential_type.value,
|
||||
).count() > 0:
|
||||
if (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
name=db_provider_name,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=credential_type.value,
|
||||
)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
db_provider_name = generate_incremental_name(
|
||||
[
|
||||
provider.name
|
||||
for provider in session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
)
|
||||
],
|
||||
db_provider_name,
|
||||
)
|
||||
[
|
||||
provider.name
|
||||
for provider in session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
)
|
||||
],
|
||||
db_provider_name,
|
||||
)
|
||||
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}"
|
||||
|
|
@ -103,7 +270,7 @@ class DatasourceProviderService:
|
|||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=credential_type.value,
|
||||
encrypted_credentials=credentials,
|
||||
avatar_url=avatar_url,
|
||||
avatar_url=avatar_url or "default",
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
|
@ -222,6 +389,7 @@ class DatasourceProviderService:
|
|||
"credential": copy_credentials,
|
||||
"type": datasource_provider.auth_type,
|
||||
"name": datasource_provider.name,
|
||||
"avatar_url": datasource_provider.avatar_url,
|
||||
"id": datasource_provider.id,
|
||||
}
|
||||
)
|
||||
|
|
@ -239,6 +407,7 @@ class DatasourceProviderService:
|
|||
datasources = manager.fetch_installed_datasource_providers(tenant_id)
|
||||
datasource_credentials = []
|
||||
for datasource in datasources:
|
||||
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
|
||||
credentials = self.get_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
|
|
@ -302,6 +471,11 @@ class DatasourceProviderService:
|
|||
}
|
||||
for credential in datasource.declaration.oauth_schema.credentials_schema or []
|
||||
],
|
||||
"oauth_custom_client_params": self.get_tenant_oauth_client(tenant_id, datasource_provider_id),
|
||||
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
|
||||
tenant_id, datasource_provider_id
|
||||
),
|
||||
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
|
||||
}
|
||||
if datasource.declaration.oauth_schema
|
||||
else None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue