feat: add datasource OAuth client setup command and refactor related models

This commit is contained in:
Harry 2025-07-18 14:11:08 +08:00
parent f153319a77
commit 6ca5bc1063
5 changed files with 142 additions and 3 deletions

View File

@ -12,7 +12,7 @@ from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin import DatasourceProviderID, ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
@ -29,6 +29,7 @@ from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.oauth import DatasourceOauthParamConfig
from models.provider import Provider, ProviderModel
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
@ -1205,3 +1206,49 @@ def setup_system_tool_oauth_client(provider, client_params):
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_datasource_oauth_client(provider, client_params):
"""
Setup datasource oauth client
"""
provider_id = DatasourceProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
deleted_count = (
db.session.query(DatasourceOauthParamConfig)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
oauth_client = DatasourceOauthParamConfig(
provider=provider_name,
plugin_id=plugin_id,
system_credentials=client_params_dict,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"provider: {provider_name}", fg="green"))
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))

View File

@ -18,6 +18,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
setup_datasource_oauth_client,
setup_system_tool_oauth_client,
upgrade_db,
vdb_migrate,
@ -42,6 +43,7 @@ def init_app(app: DifyApp):
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
setup_system_tool_oauth_client,
setup_datasource_oauth_client,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@ -0,0 +1,25 @@
"""empty message
Revision ID: bb3812d469dd
Revises: 15e40b74a6d2, 71f5020c6470
Create Date: 2025-07-18 14:09:12.778358
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'bb3812d469dd'
down_revision = ('15e40b74a6d2', '71f5020c6470')
branch_labels = None
depends_on = None
def upgrade():
pass
def downgrade():
pass

View File

@ -0,0 +1,65 @@
"""datasource_oauth_1
Revision ID: d4a76fde2724
Revises: bb3812d469dd
Create Date: 2025-07-18 14:09:42.551752
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'd4a76fde2724'
down_revision = 'bb3812d469dd'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tenant_plugin_auto_upgrade_strategies')
with op.batch_alter_table('datasource_oauth_params', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.UUID(),
type_=sa.String(length=255),
existing_nullable=False)
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.TEXT(),
type_=sa.String(length=255),
existing_nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.String(length=255),
type_=sa.TEXT(),
existing_nullable=False)
with op.batch_alter_table('datasource_oauth_params', schema=None) as batch_op:
batch_op.alter_column('plugin_id',
existing_type=sa.String(length=255),
type_=sa.UUID(),
existing_nullable=False)
op.create_table('tenant_plugin_auto_upgrade_strategies',
sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
sa.Column('strategy_setting', sa.VARCHAR(length=16), server_default=sa.text("'fix_only'::character varying"), autoincrement=False, nullable=False),
sa.Column('upgrade_time_of_day', sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column('upgrade_mode', sa.VARCHAR(length=16), server_default=sa.text("'exclude'::character varying"), autoincrement=False, nullable=False),
sa.Column('exclude_plugins', postgresql.ARRAY(sa.VARCHAR(length=255)), autoincrement=False, nullable=False),
sa.Column('include_plugins', postgresql.ARRAY(sa.VARCHAR(length=255)), autoincrement=False, nullable=False),
sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=False),
sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('tenant_plugin_auto_upgrade_strategy_pkey')),
sa.UniqueConstraint('tenant_id', name=op.f('unique_tenant_plugin_auto_upgrade_strategy'), postgresql_include=[], postgresql_nulls_not_distinct=False)
)
# ### end Alembic commands ###

View File

@ -16,7 +16,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False)
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
@ -31,7 +31,7 @@ class DatasourceProvider(Base):
tenant_id = db.Column(StringUUID, nullable=False)
name: Mapped[str] = db.Column(db.String(255), nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False)
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)