From 6ca5bc1063c242a21e4186af009c3a0ff6f567f7 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 18 Jul 2025 14:11:08 +0800 Subject: [PATCH] feat: add datasource OAuth client setup command and refactor related models --- api/commands.py | 49 +++++++++++++- api/extensions/ext_commands.py | 2 + ..._07_18_1409-bb3812d469dd_merge_rag_main.py | 25 +++++++ ...18_1409-d4a76fde2724_datasource_oauth_1.py | 65 +++++++++++++++++++ api/models/oauth.py | 4 +- 5 files changed, 142 insertions(+), 3 deletions(-) create mode 100644 api/migrations/versions/2025_07_18_1409-bb3812d469dd_merge_rag_main.py create mode 100644 api/migrations/versions/2025_07_18_1409-d4a76fde2724_datasource_oauth_1.py diff --git a/api/commands.py b/api/commands.py index 9f933a378c..eec5dda26d 100644 --- a/api/commands.py +++ b/api/commands.py @@ -12,7 +12,7 @@ from werkzeug.exceptions import NotFound from configs import dify_config from constants.languages import languages -from core.plugin.entities.plugin import ToolProviderID +from core.plugin.entities.plugin import DatasourceProviderID, ToolProviderID from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField @@ -29,6 +29,7 @@ from models import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation +from models.oauth import DatasourceOauthParamConfig from models.provider import Provider, ProviderModel from models.tools import ToolOAuthSystemClient from services.account_service import AccountService, RegisterService, TenantService @@ -1205,3 +1206,49 @@ def setup_system_tool_oauth_client(provider, client_params): db.session.add(oauth_client) db.session.commit() click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) + + +@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_datasource_oauth_client(provider, client_params): + """ + Setup datasource oauth client + """ + provider_id = DatasourceProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow")) + deleted_count = ( + db.session.query(DatasourceOauthParamConfig) + .filter_by( + provider=provider_name, + plugin_id=plugin_id, + ) + .delete() + ) + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow")) + oauth_client = DatasourceOauthParamConfig( + provider=provider_name, + plugin_id=plugin_id, + system_credentials=client_params_dict, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"provider: {provider_name}", fg="green")) + click.echo(click.style(f"plugin_id: {plugin_id}", fg="green")) + click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green")) + click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green")) diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 600e336c19..43353b9210 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -18,6 +18,7 @@ def init_app(app: DifyApp): reset_email, reset_encrypt_key_pair, reset_password, + setup_datasource_oauth_client, setup_system_tool_oauth_client, upgrade_db, vdb_migrate, @@ -42,6 +43,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, remove_orphaned_files_on_storage, setup_system_tool_oauth_client, + setup_datasource_oauth_client, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/migrations/versions/2025_07_18_1409-bb3812d469dd_merge_rag_main.py b/api/migrations/versions/2025_07_18_1409-bb3812d469dd_merge_rag_main.py new file mode 100644 index 0000000000..571ef00929 --- /dev/null +++ b/api/migrations/versions/2025_07_18_1409-bb3812d469dd_merge_rag_main.py @@ -0,0 +1,25 @@ +"""empty message + +Revision ID: bb3812d469dd +Revises: 15e40b74a6d2, 71f5020c6470 +Create Date: 2025-07-18 14:09:12.778358 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'bb3812d469dd' +down_revision = ('15e40b74a6d2', '71f5020c6470') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/api/migrations/versions/2025_07_18_1409-d4a76fde2724_datasource_oauth_1.py b/api/migrations/versions/2025_07_18_1409-d4a76fde2724_datasource_oauth_1.py new file mode 100644 index 0000000000..c6fd40e2bb --- /dev/null +++ b/api/migrations/versions/2025_07_18_1409-d4a76fde2724_datasource_oauth_1.py @@ -0,0 +1,65 @@ +"""datasource_oauth_1 + +Revision ID: d4a76fde2724 +Revises: bb3812d469dd +Create Date: 2025-07-18 14:09:42.551752 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd4a76fde2724' +down_revision = 'bb3812d469dd' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tenant_plugin_auto_upgrade_strategies') + with op.batch_alter_table('datasource_oauth_params', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.UUID(), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.TEXT(), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.String(length=255), + type_=sa.TEXT(), + existing_nullable=False) + + with op.batch_alter_table('datasource_oauth_params', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.String(length=255), + type_=sa.UUID(), + existing_nullable=False) + + op.create_table('tenant_plugin_auto_upgrade_strategies', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('strategy_setting', sa.VARCHAR(length=16), server_default=sa.text("'fix_only'::character varying"), autoincrement=False, nullable=False), + sa.Column('upgrade_time_of_day', sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column('upgrade_mode', sa.VARCHAR(length=16), server_default=sa.text("'exclude'::character varying"), autoincrement=False, nullable=False), + sa.Column('exclude_plugins', postgresql.ARRAY(sa.VARCHAR(length=255)), autoincrement=False, nullable=False), + sa.Column('include_plugins', postgresql.ARRAY(sa.VARCHAR(length=255)), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name=op.f('tenant_plugin_auto_upgrade_strategy_pkey')), + sa.UniqueConstraint('tenant_id', name=op.f('unique_tenant_plugin_auto_upgrade_strategy'), postgresql_include=[], postgresql_nulls_not_distinct=False) + ) + # ### end Alembic commands ### diff --git a/api/models/oauth.py b/api/models/oauth.py index 84bc29931e..ed839e4fb1 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -16,7 +16,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) + plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) @@ -31,7 +31,7 @@ class DatasourceProvider(Base): tenant_id = db.Column(StringUUID, nullable=False) name: Mapped[str] = db.Column(db.String(255), nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) - plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False) + plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)