mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 20:48:01 +08:00
feat: add APIs for setting default datasource provider and updating provider name
This commit is contained in:
parent
9c96f1db6c
commit
af94602d37
@ -205,6 +205,7 @@ class DatasourceAuthListApi(Resource):
|
|||||||
)
|
)
|
||||||
return {"result": jsonable_encoder(datasources)}, 200
|
return {"result": jsonable_encoder(datasources)}, 200
|
||||||
|
|
||||||
|
|
||||||
class DatasourceAuthOauthCustomClient(Resource):
|
class DatasourceAuthOauthCustomClient(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -227,6 +228,48 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceAuthDefaultApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider_id: str):
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
datasource_provider_service.set_default_datasource_provider(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
datasource_provider_id=datasource_provider_id,
|
||||||
|
credential_id=args["credential_id"],
|
||||||
|
)
|
||||||
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
class DatasourceUpdateProviderNameApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider_id: str):
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
datasource_provider_service.update_datasource_provider_name(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
datasource_provider_id=datasource_provider_id,
|
||||||
|
name=args["name"],
|
||||||
|
credential_id=args["credential_id"],
|
||||||
|
)
|
||||||
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
DatasourcePluginOAuthAuthorizationUrl,
|
DatasourcePluginOAuthAuthorizationUrl,
|
||||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
||||||
@ -254,3 +297,13 @@ api.add_resource(
|
|||||||
DatasourceAuthOauthCustomClient,
|
DatasourceAuthOauthCustomClient,
|
||||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
DatasourceAuthDefaultApi,
|
||||||
|
"/auth/plugin/datasource/<path:provider_id>/default",
|
||||||
|
)
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
DatasourceUpdateProviderNameApi,
|
||||||
|
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
||||||
|
)
|
||||||
@ -127,13 +127,13 @@ class DatasourceNode(BaseNode):
|
|||||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
credentials = datasource_provider_service.get_real_datasource_credentials(
|
credentials = datasource_provider_service.get_default_credentials(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
provider=node_data.provider_name,
|
provider=node_data.provider_name,
|
||||||
plugin_id=node_data.plugin_id,
|
plugin_id=node_data.plugin_id,
|
||||||
)
|
)
|
||||||
if credentials:
|
if credentials:
|
||||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
datasource_runtime.runtime.credentials = credentials
|
||||||
online_document_result: Generator[DatasourceMessage, None, None] = (
|
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||||
datasource_runtime.get_online_document_page_content(
|
datasource_runtime.get_online_document_page_content(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
@ -159,7 +159,7 @@ class DatasourceNode(BaseNode):
|
|||||||
plugin_id=node_data.plugin_id,
|
plugin_id=node_data.plugin_id,
|
||||||
)
|
)
|
||||||
if credentials:
|
if credentials:
|
||||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
datasource_runtime.runtime.credentials = credentials
|
||||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||||
datasource_runtime.online_drive_download_file(
|
datasource_runtime.online_drive_download_file(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
|
|||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""add_pipeline_info_15
|
||||||
|
|
||||||
|
Revision ID: 74e5f667f4b7
|
||||||
|
Revises: d3c68680d3ba
|
||||||
|
Create Date: 2025-07-21 15:23:20.825747
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '74e5f667f4b7'
|
||||||
|
down_revision = 'd3c68680d3ba'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('is_default')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -36,6 +36,7 @@ class DatasourceProvider(Base):
|
|||||||
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||||
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
|
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
|
||||||
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
|
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
|
||||||
|
is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||||
|
|
||||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||||
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||||
|
|||||||
@ -29,6 +29,96 @@ class DatasourceProviderService:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.provider_manager = PluginDatasourceManager()
|
self.provider_manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
get default credentials
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
datasource_provider = (
|
||||||
|
session.query(DatasourceProvider)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||||
|
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not datasource_provider:
|
||||||
|
return {}
|
||||||
|
return datasource_provider.encrypted_credentials
|
||||||
|
|
||||||
|
def update_datasource_provider_name(
|
||||||
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
update datasource provider name
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
target_provider = (
|
||||||
|
session.query(DatasourceProvider)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
id=credential_id,
|
||||||
|
provider=datasource_provider_id.provider_name,
|
||||||
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if target_provider is None:
|
||||||
|
raise ValueError("provider not found")
|
||||||
|
|
||||||
|
if target_provider.name == name:
|
||||||
|
return
|
||||||
|
|
||||||
|
# check name is exist
|
||||||
|
if (
|
||||||
|
session.query(DatasourceProvider)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
name=name,
|
||||||
|
provider=datasource_provider_id.provider_name,
|
||||||
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
> 0
|
||||||
|
):
|
||||||
|
raise ValueError("name is already exists")
|
||||||
|
|
||||||
|
target_provider.name = name
|
||||||
|
session.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
def set_default_datasource_provider(
|
||||||
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
set default datasource provider
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
# get provider
|
||||||
|
target_provider = (
|
||||||
|
session.query(DatasourceProvider)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
id=credential_id,
|
||||||
|
provider=datasource_provider_id.provider_name,
|
||||||
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if target_provider is None:
|
||||||
|
raise ValueError("provider not found")
|
||||||
|
|
||||||
|
# clear default provider
|
||||||
|
session.query(DatasourceProvider).filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=target_provider.provider,
|
||||||
|
plugin_id=target_provider.plugin_id,
|
||||||
|
is_default=True,
|
||||||
|
).update({"is_default": False})
|
||||||
|
|
||||||
|
# set new default provider
|
||||||
|
target_provider.is_default = True
|
||||||
|
session.commit()
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
def setup_oauth_custom_client_params(
|
def setup_oauth_custom_client_params(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
@ -41,10 +131,6 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
if client_params is None and enabled is None:
|
if client_params is None and enabled is None:
|
||||||
return
|
return
|
||||||
provider_controller = PluginDatasourceManager()
|
|
||||||
datasource_provider = provider_controller.fetch_datasource_provider(
|
|
||||||
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
|
||||||
)
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
tenant_oauth_client_params = (
|
tenant_oauth_client_params = (
|
||||||
session.query(DatasourceOauthTenantParamConfig)
|
session.query(DatasourceOauthTenantParamConfig)
|
||||||
@ -252,7 +338,7 @@ class DatasourceProviderService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
provider_credential_secret_variables = self.extract_secret_variables(
|
provider_credential_secret_variables = self.extract_secret_variables(
|
||||||
tenant_id=tenant_id, provider_id=f"{provider_id}"
|
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type.value
|
||||||
)
|
)
|
||||||
for key, value in credentials.items():
|
for key, value in credentials.items():
|
||||||
if key in provider_credential_secret_variables:
|
if key in provider_credential_secret_variables:
|
||||||
@ -310,7 +396,7 @@ class DatasourceProviderService:
|
|||||||
)
|
)
|
||||||
if credential_valid:
|
if credential_valid:
|
||||||
provider_credential_secret_variables = self.extract_secret_variables(
|
provider_credential_secret_variables = self.extract_secret_variables(
|
||||||
tenant_id=tenant_id, provider_id=f"{provider_id}"
|
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY.value
|
||||||
)
|
)
|
||||||
for key, value in credentials.items():
|
for key, value in credentials.items():
|
||||||
if key in provider_credential_secret_variables:
|
if key in provider_credential_secret_variables:
|
||||||
@ -329,7 +415,7 @@ class DatasourceProviderService:
|
|||||||
else:
|
else:
|
||||||
raise CredentialsValidateFailedError()
|
raise CredentialsValidateFailedError()
|
||||||
|
|
||||||
def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
|
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: str) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Extract secret input form variables.
|
Extract secret input form variables.
|
||||||
|
|
||||||
@ -339,7 +425,16 @@ class DatasourceProviderService:
|
|||||||
datasource_provider = self.provider_manager.fetch_datasource_provider(
|
datasource_provider = self.provider_manager.fetch_datasource_provider(
|
||||||
tenant_id=tenant_id, provider_id=provider_id
|
tenant_id=tenant_id, provider_id=provider_id
|
||||||
)
|
)
|
||||||
credential_form_schemas = datasource_provider.declaration.credentials_schema
|
credential_form_schemas = []
|
||||||
|
if credential_type == "api_key":
|
||||||
|
credential_form_schemas = datasource_provider.declaration.credentials_schema
|
||||||
|
elif credential_type == "oauth2":
|
||||||
|
if not datasource_provider.declaration.oauth_schema:
|
||||||
|
raise ValueError("Datasource provider oauth schema not found")
|
||||||
|
credential_form_schemas = datasource_provider.declaration.oauth_schema.credentials_schema
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||||
|
|
||||||
secret_input_form_variables = []
|
secret_input_form_variables = []
|
||||||
for credential_form_schema in credential_form_schemas:
|
for credential_form_schema in credential_form_schemas:
|
||||||
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
||||||
@ -368,11 +463,20 @@ class DatasourceProviderService:
|
|||||||
if not datasource_providers:
|
if not datasource_providers:
|
||||||
return []
|
return []
|
||||||
copy_credentials_list = []
|
copy_credentials_list = []
|
||||||
|
default_provider = (
|
||||||
|
db.session.query(DatasourceProvider.id)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||||
|
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
default_provider_id = default_provider.id if default_provider else None
|
||||||
for datasource_provider in datasource_providers:
|
for datasource_provider in datasource_providers:
|
||||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
credential_secret_variables = self.extract_secret_variables(
|
credential_secret_variables = self.extract_secret_variables(
|
||||||
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
|
tenant_id=tenant_id,
|
||||||
|
provider_id=f"{plugin_id}/{provider}",
|
||||||
|
credential_type=datasource_provider.auth_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Obfuscate provider credentials
|
# Obfuscate provider credentials
|
||||||
@ -387,6 +491,7 @@ class DatasourceProviderService:
|
|||||||
"name": datasource_provider.name,
|
"name": datasource_provider.name,
|
||||||
"avatar_url": datasource_provider.avatar_url,
|
"avatar_url": datasource_provider.avatar_url,
|
||||||
"id": datasource_provider.id,
|
"id": datasource_provider.id,
|
||||||
|
"is_default": default_provider_id and datasource_provider.id == default_provider_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -469,7 +574,9 @@ class DatasourceProviderService:
|
|||||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
credential_secret_variables = self.extract_secret_variables(
|
credential_secret_variables = self.extract_secret_variables(
|
||||||
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
|
tenant_id=tenant_id,
|
||||||
|
provider_id=f"{plugin_id}/{provider}",
|
||||||
|
credential_type=datasource_provider.auth_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Obfuscate provider credentials
|
# Obfuscate provider credentials
|
||||||
@ -507,12 +614,14 @@ class DatasourceProviderService:
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
provider_credential_secret_variables = self.extract_secret_variables(
|
|
||||||
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
|
|
||||||
)
|
|
||||||
if not datasource_provider:
|
if not datasource_provider:
|
||||||
raise ValueError("Datasource provider not found")
|
raise ValueError("Datasource provider not found")
|
||||||
else:
|
else:
|
||||||
|
provider_credential_secret_variables = self.extract_secret_variables(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=f"{plugin_id}/{provider}",
|
||||||
|
credential_type=datasource_provider.auth_type,
|
||||||
|
)
|
||||||
original_credentials = datasource_provider.encrypted_credentials
|
original_credentials = datasource_provider.encrypted_credentials
|
||||||
for key, value in credentials.items():
|
for key, value in credentials.items():
|
||||||
if key in provider_credential_secret_variables:
|
if key in provider_credential_secret_variables:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user