From af94602d37a630e74518d5d22fbf82bbb5e62653 Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 21 Jul 2025 15:49:26 +0800 Subject: [PATCH] feat: add APIs for setting default datasource provider and updating provider name --- .../datasets/rag_pipeline/datasource_auth.py | 53 +++++++ .../nodes/datasource/datasource_node.py | 6 +- ..._1523-74e5f667f4b7_add_pipeline_info_15.py | 33 +++++ api/models/oauth.py | 1 + api/services/datasource_provider_service.py | 135 ++++++++++++++++-- 5 files changed, 212 insertions(+), 16 deletions(-) create mode 100644 api/migrations/versions/2025_07_21_1523-74e5f667f4b7_add_pipeline_info_15.py diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 7442a5001a..d1e4812bcd 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -205,6 +205,7 @@ class DatasourceAuthListApi(Resource): ) return {"result": jsonable_encoder(datasources)}, 200 + class DatasourceAuthOauthCustomClient(Resource): @setup_required @login_required @@ -227,6 +228,48 @@ class DatasourceAuthOauthCustomClient(Resource): ) return {"result": "success"}, 200 + +class DatasourceAuthDefaultApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider_id: str): + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + datasource_provider_id = DatasourceProviderID(provider_id) + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.set_default_datasource_provider( + tenant_id=current_user.current_tenant_id, + datasource_provider_id=datasource_provider_id, + credential_id=args["credential_id"], + ) + return {"result": "success"}, 200 + +class DatasourceUpdateProviderNameApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider_id: str): + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + datasource_provider_id = DatasourceProviderID(provider_id) + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.update_datasource_provider_name( + tenant_id=current_user.current_tenant_id, + datasource_provider_id=datasource_provider_id, + name=args["name"], + credential_id=args["credential_id"], + ) + return {"result": "success"}, 200 + + api.add_resource( DatasourcePluginOAuthAuthorizationUrl, "/oauth/plugin//datasource/get-authorization-url", @@ -254,3 +297,13 @@ api.add_resource( DatasourceAuthOauthCustomClient, "/auth/plugin/datasource//custom-client", ) + +api.add_resource( + DatasourceAuthDefaultApi, + "/auth/plugin/datasource//default", +) + +api.add_resource( + DatasourceUpdateProviderNameApi, + "/auth/plugin/datasource//update-name", +) \ No newline at end of file diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 9e271b8c16..c74bf2a86e 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -127,13 +127,13 @@ class DatasourceNode(BaseNode): case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_datasource_credentials( + credentials = datasource_provider_service.get_default_credentials( tenant_id=self.tenant_id, provider=node_data.provider_name, plugin_id=node_data.plugin_id, ) if credentials: - datasource_runtime.runtime.credentials = credentials[0].get("credentials") + datasource_runtime.runtime.credentials = credentials online_document_result: Generator[DatasourceMessage, None, None] = ( datasource_runtime.get_online_document_page_content( user_id=self.user_id, @@ -159,7 +159,7 @@ class DatasourceNode(BaseNode): plugin_id=node_data.plugin_id, ) if credentials: - datasource_runtime.runtime.credentials = credentials[0].get("credentials") + datasource_runtime.runtime.credentials = credentials online_drive_result: Generator[DatasourceMessage, None, None] = ( datasource_runtime.online_drive_download_file( user_id=self.user_id, diff --git a/api/migrations/versions/2025_07_21_1523-74e5f667f4b7_add_pipeline_info_15.py b/api/migrations/versions/2025_07_21_1523-74e5f667f4b7_add_pipeline_info_15.py new file mode 100644 index 0000000000..79ec4eb4da --- /dev/null +++ b/api/migrations/versions/2025_07_21_1523-74e5f667f4b7_add_pipeline_info_15.py @@ -0,0 +1,33 @@ +"""add_pipeline_info_15 + +Revision ID: 74e5f667f4b7 +Revises: d3c68680d3ba +Create Date: 2025-07-21 15:23:20.825747 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '74e5f667f4b7' +down_revision = 'd3c68680d3ba' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.drop_column('is_default') + + # ### end Alembic commands ### diff --git a/api/models/oauth.py b/api/models/oauth.py index 65bcb5c0a3..8e661051a7 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -36,6 +36,7 @@ class DatasourceProvider(Base): auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default") + is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 513d7366c1..a904215823 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -29,6 +29,96 @@ class DatasourceProviderService: def __init__(self) -> None: self.provider_manager = PluginDatasourceManager() + def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]: + """ + get default credentials + """ + with Session(db.engine) as session: + datasource_provider = ( + session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) + .first() + ) + if not datasource_provider: + return {} + return datasource_provider.encrypted_credentials + + def update_datasource_provider_name( + self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str + ): + """ + update datasource provider name + """ + with Session(db.engine) as session: + target_provider = ( + session.query(DatasourceProvider) + .filter_by( + tenant_id=tenant_id, + id=credential_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + ) + .first() + ) + if target_provider is None: + raise ValueError("provider not found") + + if target_provider.name == name: + return + + # check name is exist + if ( + session.query(DatasourceProvider) + .filter_by( + tenant_id=tenant_id, + name=name, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + ) + .count() + > 0 + ): + raise ValueError("name is already exists") + + target_provider.name = name + session.commit() + return + + def set_default_datasource_provider( + self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str + ): + """ + set default datasource provider + """ + with Session(db.engine) as session: + # get provider + target_provider = ( + session.query(DatasourceProvider) + .filter_by( + tenant_id=tenant_id, + id=credential_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + ) + .first() + ) + if target_provider is None: + raise ValueError("provider not found") + + # clear default provider + session.query(DatasourceProvider).filter_by( + tenant_id=tenant_id, + provider=target_provider.provider, + plugin_id=target_provider.plugin_id, + is_default=True, + ).update({"is_default": False}) + + # set new default provider + target_provider.is_default = True + session.commit() + return {"result": "success"} + def setup_oauth_custom_client_params( self, tenant_id: str, @@ -41,10 +131,6 @@ class DatasourceProviderService: """ if client_params is None and enabled is None: return - provider_controller = PluginDatasourceManager() - datasource_provider = provider_controller.fetch_datasource_provider( - tenant_id=tenant_id, provider_id=str(datasource_provider_id) - ) with Session(db.engine) as session: tenant_oauth_client_params = ( session.query(DatasourceOauthTenantParamConfig) @@ -252,7 +338,7 @@ class DatasourceProviderService: ) provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{provider_id}" + tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type.value ) for key, value in credentials.items(): if key in provider_credential_secret_variables: @@ -310,7 +396,7 @@ class DatasourceProviderService: ) if credential_valid: provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{provider_id}" + tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY.value ) for key, value in credentials.items(): if key in provider_credential_secret_variables: @@ -329,7 +415,7 @@ class DatasourceProviderService: else: raise CredentialsValidateFailedError() - def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]: + def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: str) -> list[str]: """ Extract secret input form variables. @@ -339,7 +425,16 @@ class DatasourceProviderService: datasource_provider = self.provider_manager.fetch_datasource_provider( tenant_id=tenant_id, provider_id=provider_id ) - credential_form_schemas = datasource_provider.declaration.credentials_schema + credential_form_schemas = [] + if credential_type == "api_key": + credential_form_schemas = datasource_provider.declaration.credentials_schema + elif credential_type == "oauth2": + if not datasource_provider.declaration.oauth_schema: + raise ValueError("Datasource provider oauth schema not found") + credential_form_schemas = datasource_provider.declaration.oauth_schema.credentials_schema + else: + raise ValueError(f"Invalid credential type: {credential_type}") + secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: if credential_form_schema.type.value == FormType.SECRET_INPUT.value: @@ -368,11 +463,20 @@ class DatasourceProviderService: if not datasource_providers: return [] copy_credentials_list = [] + default_provider = ( + db.session.query(DatasourceProvider.id) + .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) + .first() + ) + default_provider_id = default_provider.id if default_provider else None for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}", + credential_type=datasource_provider.auth_type, ) # Obfuscate provider credentials @@ -387,6 +491,7 @@ class DatasourceProviderService: "name": datasource_provider.name, "avatar_url": datasource_provider.avatar_url, "id": datasource_provider.id, + "is_default": default_provider_id and datasource_provider.id == default_provider_id, } ) @@ -469,7 +574,9 @@ class DatasourceProviderService: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}", + credential_type=datasource_provider.auth_type, ) # Obfuscate provider credentials @@ -507,12 +614,14 @@ class DatasourceProviderService: .first() ) - provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" - ) if not datasource_provider: raise ValueError("Datasource provider not found") else: + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}", + credential_type=datasource_provider.auth_type, + ) original_credentials = datasource_provider.encrypted_credentials for key, value in credentials.items(): if key in provider_credential_secret_variables: