mirror of https://github.com/langgenius/dify.git
feat: add datasource OAuth client setup command and refactor related models
This commit is contained in:
parent
f153319a77
commit
6ca5bc1063
|
|
@ -12,7 +12,7 @@ from werkzeug.exceptions import NotFound
|
|||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin import DatasourceProviderID, ToolProviderID
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
|
|
@ -29,6 +29,7 @@ from models import Tenant
|
|||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.oauth import DatasourceOauthParamConfig
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
|
|
@ -1205,3 +1206,49 @@ def setup_system_tool_oauth_client(provider, client_params):
|
|||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_datasource_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup datasource oauth client
|
||||
"""
|
||||
provider_id = DatasourceProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
|
||||
oauth_client = DatasourceOauthParamConfig(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
system_credentials=client_params_dict,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"provider: {provider_name}", fg="green"))
|
||||
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
|
||||
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
|
||||
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ def init_app(app: DifyApp):
|
|||
reset_email,
|
||||
reset_encrypt_key_pair,
|
||||
reset_password,
|
||||
setup_datasource_oauth_client,
|
||||
setup_system_tool_oauth_client,
|
||||
upgrade_db,
|
||||
vdb_migrate,
|
||||
|
|
@ -42,6 +43,7 @@ def init_app(app: DifyApp):
|
|||
clear_orphaned_file_records,
|
||||
remove_orphaned_files_on_storage,
|
||||
setup_system_tool_oauth_client,
|
||||
setup_datasource_oauth_client,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
"""empty message
|
||||
|
||||
Revision ID: bb3812d469dd
|
||||
Revises: 15e40b74a6d2, 71f5020c6470
|
||||
Create Date: 2025-07-18 14:09:12.778358
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'bb3812d469dd'
|
||||
down_revision = ('15e40b74a6d2', '71f5020c6470')
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
pass
|
||||
|
||||
|
||||
def downgrade():
|
||||
pass
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
"""datasource_oauth_1
|
||||
|
||||
Revision ID: d4a76fde2724
|
||||
Revises: bb3812d469dd
|
||||
Create Date: 2025-07-18 14:09:42.551752
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd4a76fde2724'
|
||||
down_revision = 'bb3812d469dd'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('tenant_plugin_auto_upgrade_strategies')
|
||||
with op.batch_alter_table('datasource_oauth_params', schema=None) as batch_op:
|
||||
batch_op.alter_column('plugin_id',
|
||||
existing_type=sa.UUID(),
|
||||
type_=sa.String(length=255),
|
||||
existing_nullable=False)
|
||||
|
||||
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('plugin_id',
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.String(length=255),
|
||||
existing_nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('plugin_id',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.TEXT(),
|
||||
existing_nullable=False)
|
||||
|
||||
with op.batch_alter_table('datasource_oauth_params', schema=None) as batch_op:
|
||||
batch_op.alter_column('plugin_id',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.UUID(),
|
||||
existing_nullable=False)
|
||||
|
||||
op.create_table('tenant_plugin_auto_upgrade_strategies',
|
||||
sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
|
||||
sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column('strategy_setting', sa.VARCHAR(length=16), server_default=sa.text("'fix_only'::character varying"), autoincrement=False, nullable=False),
|
||||
sa.Column('upgrade_time_of_day', sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column('upgrade_mode', sa.VARCHAR(length=16), server_default=sa.text("'exclude'::character varying"), autoincrement=False, nullable=False),
|
||||
sa.Column('exclude_plugins', postgresql.ARRAY(sa.VARCHAR(length=255)), autoincrement=False, nullable=False),
|
||||
sa.Column('include_plugins', postgresql.ARRAY(sa.VARCHAR(length=255)), autoincrement=False, nullable=False),
|
||||
sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=False),
|
||||
sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('tenant_plugin_auto_upgrade_strategy_pkey')),
|
||||
sa.UniqueConstraint('tenant_id', name=op.f('unique_tenant_plugin_auto_upgrade_strategy'), postgresql_include=[], postgresql_nulls_not_distinct=False)
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -16,7 +16,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
|
|||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ class DatasourceProvider(Base):
|
|||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
name: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False)
|
||||
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||
|
|
|
|||
Loading…
Reference in New Issue