mirror of https://github.com/langgenius/dify.git
feat: add APIs for setting default datasource provider and updating provider name
This commit is contained in:
parent
9c96f1db6c
commit
af94602d37
|
|
@ -205,6 +205,7 @@ class DatasourceAuthListApi(Resource):
|
|||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -227,6 +228,48 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
||||
|
|
@ -254,3 +297,13 @@ api.add_resource(
|
|||
DatasourceAuthOauthCustomClient,
|
||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDefaultApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/default",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceUpdateProviderNameApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
||||
)
|
||||
|
|
@ -127,13 +127,13 @@ class DatasourceNode(BaseNode):
|
|||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_real_datasource_credentials(
|
||||
credentials = datasource_provider_service.get_default_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
|
|
@ -159,7 +159,7 @@ class DatasourceNode(BaseNode):
|
|||
plugin_id=node_data.plugin_id,
|
||||
)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
"""add_pipeline_info_15
|
||||
|
||||
Revision ID: 74e5f667f4b7
|
||||
Revises: d3c68680d3ba
|
||||
Create Date: 2025-07-21 15:23:20.825747
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '74e5f667f4b7'
|
||||
down_revision = 'd3c68680d3ba'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('is_default')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -36,6 +36,7 @@ class DatasourceProvider(Base):
|
|||
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
|
||||
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
|
||||
is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,96 @@ class DatasourceProviderService:
|
|||
def __init__(self) -> None:
|
||||
self.provider_manager = PluginDatasourceManager()
|
||||
|
||||
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
get default credentials
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
if not datasource_provider:
|
||||
return {}
|
||||
return datasource_provider.encrypted_credentials
|
||||
|
||||
def update_datasource_provider_name(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
|
||||
):
|
||||
"""
|
||||
update datasource provider name
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
id=credential_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
if target_provider.name == name:
|
||||
return
|
||||
|
||||
# check name is exist
|
||||
if (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
raise ValueError("name is already exists")
|
||||
|
||||
target_provider.name = name
|
||||
session.commit()
|
||||
return
|
||||
|
||||
def set_default_datasource_provider(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
|
||||
):
|
||||
"""
|
||||
set default datasource provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# get provider
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
id=credential_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# clear default provider
|
||||
session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=target_provider.provider,
|
||||
plugin_id=target_provider.plugin_id,
|
||||
is_default=True,
|
||||
).update({"is_default": False})
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
def setup_oauth_custom_client_params(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
|
@ -41,10 +131,6 @@ class DatasourceProviderService:
|
|||
"""
|
||||
if client_params is None and enabled is None:
|
||||
return
|
||||
provider_controller = PluginDatasourceManager()
|
||||
datasource_provider = provider_controller.fetch_datasource_provider(
|
||||
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
|
|
@ -252,7 +338,7 @@ class DatasourceProviderService:
|
|||
)
|
||||
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}"
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type.value
|
||||
)
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
|
|
@ -310,7 +396,7 @@ class DatasourceProviderService:
|
|||
)
|
||||
if credential_valid:
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}"
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY.value
|
||||
)
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
|
|
@ -329,7 +415,7 @@ class DatasourceProviderService:
|
|||
else:
|
||||
raise CredentialsValidateFailedError()
|
||||
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: str) -> list[str]:
|
||||
"""
|
||||
Extract secret input form variables.
|
||||
|
||||
|
|
@ -339,7 +425,16 @@ class DatasourceProviderService:
|
|||
datasource_provider = self.provider_manager.fetch_datasource_provider(
|
||||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
credential_form_schemas = datasource_provider.declaration.credentials_schema
|
||||
credential_form_schemas = []
|
||||
if credential_type == "api_key":
|
||||
credential_form_schemas = datasource_provider.declaration.credentials_schema
|
||||
elif credential_type == "oauth2":
|
||||
if not datasource_provider.declaration.oauth_schema:
|
||||
raise ValueError("Datasource provider oauth schema not found")
|
||||
credential_form_schemas = datasource_provider.declaration.oauth_schema.credentials_schema
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
||||
|
|
@ -368,11 +463,20 @@ class DatasourceProviderService:
|
|||
if not datasource_providers:
|
||||
return []
|
||||
copy_credentials_list = []
|
||||
default_provider = (
|
||||
db.session.query(DatasourceProvider.id)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
default_provider_id = default_provider.id if default_provider else None
|
||||
for datasource_provider in datasource_providers:
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=datasource_provider.auth_type,
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
|
|
@ -387,6 +491,7 @@ class DatasourceProviderService:
|
|||
"name": datasource_provider.name,
|
||||
"avatar_url": datasource_provider.avatar_url,
|
||||
"id": datasource_provider.id,
|
||||
"is_default": default_provider_id and datasource_provider.id == default_provider_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -469,7 +574,9 @@ class DatasourceProviderService:
|
|||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=datasource_provider.auth_type,
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
|
|
@ -507,12 +614,14 @@ class DatasourceProviderService:
|
|||
.first()
|
||||
)
|
||||
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
|
||||
)
|
||||
if not datasource_provider:
|
||||
raise ValueError("Datasource provider not found")
|
||||
else:
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=datasource_provider.auth_type,
|
||||
)
|
||||
original_credentials = datasource_provider.encrypted_credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
|
|
|
|||
Loading…
Reference in New Issue