feat: add custom OAuth client setup and enhance datasource provider model with avatar_url

This commit is contained in:
Harry 2025-07-21 12:35:07 +08:00
parent 7364d051d2
commit e97f03c130
7 changed files with 294 additions and 33 deletions

View File

@ -81,12 +81,13 @@ class DatasourceOAuthCallback(Resource):
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = datasource_provider_id.provider_name
plugin_id = datasource_provider_id.plugin_id
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first()
datasource_provider_service = DatasourceProviderService()
oauth_client_params = datasource_provider_service.get_oauth_client(
tenant_id=tenant_id,
datasource_provider_id=datasource_provider_id,
)
if not plugin_oauth_config:
if not oauth_client_params:
raise NotFound()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
oauth_handler = OAuthHandler()
@ -96,10 +97,9 @@ class DatasourceOAuthCallback(Resource):
plugin_id=plugin_id,
provider=datasource_provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=plugin_oauth_config.system_credentials,
system_credentials=oauth_client_params,
request=request,
)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.add_datasource_oauth_provider(
tenant_id=tenant_id,
provider_id=datasource_provider_id,
@ -205,8 +205,28 @@ class DatasourceAuthListApi(Resource):
)
return {"result": datasources}, 200
class DatasourceAuthOauthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enabled", False),
)
return {"result": "success"}, 200
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
@ -229,3 +249,8 @@ api.add_resource(
DatasourceAuthListApi,
"/auth/plugin/datasource/list",
)
api.add_resource(
DatasourceAuthOauthCustomClient,
"/auth/plugin/datasource/<path:provider_id>/custom-client",
)

View File

@ -361,4 +361,4 @@ class PluginDatasourceManager(BasePluginClient):
}
],
},
}
}

View File

@ -124,11 +124,15 @@ class ProviderConfigEncrypter:
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
def create_provider_encrypter(
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
def create_tool_provider_encrypter(
tenant_id: str, controller: ToolProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any, Mapping
from collections.abc import Mapping, Sequence
from typing import Any
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator

View File

@ -0,0 +1,40 @@
"""add_pipeline_info_14
Revision ID: d3c68680d3ba
Revises: fcb46171d891
Create Date: 2025-07-21 12:20:29.582951
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'd3c68680d3ba'
down_revision = 'fcb46171d891'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('datasource_oauth_tenant_params',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('plugin_id', sa.String(length=255), nullable=False),
sa.Column('client_params', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('datasource_oauth_tenant_params')
# ### end Alembic commands ###

View File

@ -35,7 +35,25 @@ class DatasourceProvider(Base):
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True)
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
class DatasourceOauthTenantParamConfig(Base):
__tablename__ = "datasource_oauth_tenant_params"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
client_params: Mapped[dict] = db.Column(JSONB, nullable=False, default={})
enabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)

View File

@ -1,19 +1,22 @@
import logging
from typing import Any
from flask_login import current_user
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.model_runtime.entities.provider_entities import FormType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceProvider
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
logger = logging.getLogger(__name__)
@ -26,6 +29,165 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()
def setup_oauth_custom_client_params(
self,
tenant_id: str,
datasource_provider_id: DatasourceProviderID,
client_params: dict | None,
enabled: bool | None,
):
"""
setup oauth custom client params
"""
if client_params is None and enabled is None:
return
provider_controller = PluginDatasourceManager()
datasource_provider = provider_controller.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
with Session(db.engine) as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if not tenant_oauth_client_params:
tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
client_params={},
enabled=False,
)
session.add(tenant_oauth_client_params)
if client_params is not None:
client_schema = datasource_provider.declaration.oauth_schema.client_schema
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in client_schema],
cache=NoOpProviderCredentialCache(),
)
original_params = (
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
)
new_params: dict = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
if enabled is not None:
tenant_oauth_client_params.enabled = enabled
session.commit()
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check if system oauth params exist
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthParamConfig)
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
.first()
is not None
)
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check if tenant oauth params is enabled
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
)
.count()
> 0
)
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
) -> dict[str, Any] | None:
"""
get tenant oauth client
"""
with Session(db.engine).no_autoflush as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return None
def get_oauth_encrypter(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
"""
get oauth encrypter
"""
datasource_provider = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
client_schema = datasource_provider.declaration.oauth_schema.client_schema
return create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in client_schema],
cache=NoOpProviderCredentialCache(),
)
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
"""
get oauth client
"""
provider = datasource_provider_id.provider_name
plugin_id = datasource_provider_id.plugin_id
with Session(db.engine).no_autoflush as session:
# get tenant oauth client params
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
)
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return encrypter.decrypt(tenant_oauth_client_params.client_params)
# fallback to system oauth client params
oauth_client_params = (
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if oauth_client_params:
return oauth_client_params.system_credentials
raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
@staticmethod
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
@ -69,24 +231,29 @@ class DatasourceProviderService:
credential_type=credential_type,
)
else:
if session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
).count() > 0:
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
)
.count()
> 0
):
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}"
@ -103,7 +270,7 @@ class DatasourceProviderService:
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
encrypted_credentials=credentials,
avatar_url=avatar_url,
avatar_url=avatar_url or "default",
)
session.add(datasource_provider)
session.commit()
@ -222,6 +389,7 @@ class DatasourceProviderService:
"credential": copy_credentials,
"type": datasource_provider.auth_type,
"name": datasource_provider.name,
"avatar_url": datasource_provider.avatar_url,
"id": datasource_provider.id,
}
)
@ -239,6 +407,7 @@ class DatasourceProviderService:
datasources = manager.fetch_installed_datasource_providers(tenant_id)
datasource_credentials = []
for datasource in datasources:
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
credentials = self.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
@ -302,6 +471,11 @@ class DatasourceProviderService:
}
for credential in datasource.declaration.oauth_schema.credentials_schema or []
],
"oauth_custom_client_params": self.get_tenant_oauth_client(tenant_id, datasource_provider_id),
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
tenant_id, datasource_provider_id
),
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
}
if datasource.declaration.oauth_schema
else None,