From 678d6ffe2bb76f953fa38900239ede56fbf301c3 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 26 May 2025 17:00:16 +0800 Subject: [PATCH] r2 --- .../console/auth/data_source_oauth.py | 21 ------ api/controllers/console/auth/oauth.py | 67 ++++++++++++++++++- .../datasets/rag_pipeline/datasource_oauth.py | 56 ++++++++++++++++ .../rag_pipeline/rag_pipeline_workflow.py | 2 +- api/core/app/apps/pipeline/pipeline_runner.py | 4 +- api/models/oauth.py | 47 +++++++++++++ 6 files changed, 171 insertions(+), 26 deletions(-) create mode 100644 api/controllers/console/datasets/rag_pipeline/datasource_oauth.py create mode 100644 api/models/oauth.py diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0f3e2582c4..5299064e17 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -8,7 +8,6 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from core.plugin.impl.oauth import OAuthHandler from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -110,29 +109,9 @@ class OAuthDataSourceSync(Resource): return {"result": "success"}, 200 -class DatasourcePluginOauthApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, datasource_type, datasource_name): - # Check user role first - if not current_user.is_editor: - raise Forbidden() - # get all builtin providers - oauth_handler = OAuthHandler() - providers = oauth_handler.get_authorization_url( - current_user.current_tenant.id, - current_user.id, - datasource_type, - datasource_name, - system_credentials={} - ) - return providers - api.add_resource(OAuthDataSource, "/oauth/data-source/") api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") -api.add_resource(DatasourcePluginOauthApi, "/oauth/plugin/datasource//") \ No newline at end of file diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 395367c9e2..d8576c3879 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -2,26 +2,30 @@ import logging from datetime import UTC, datetime from typing import Optional +from flask_login import current_user import requests from flask import current_app, redirect, request from flask_restful import Resource from sqlalchemy import select from sqlalchemy.orm import Session -from werkzeug.exceptions import Unauthorized +from werkzeug.exceptions import Unauthorized, Forbidden, NotFound from configs import dify_config from constants.languages import languages +from controllers.console.wraps import account_initialization_required, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import extract_remote_ip +from libs.login import login_required from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account from models.account import AccountStatus +from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService - +from core.plugin.impl.oauth import OAuthHandler from .. import api @@ -181,5 +185,64 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account +class PluginOauthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + # get all plugin oauth configs + plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( + provider=provider, + plugin_id=plugin_id + ).first() + if not plugin_oauth_config: + raise NotFound() + oauth_handler = OAuthHandler() + response = oauth_handler.get_authorization_url( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=plugin_oauth_config.system_credentials + ) + return response.model_dump() + +class PluginOauthCallback(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + oauth_handler = OAuthHandler() + plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( + provider=provider, + plugin_id=plugin_id + ).first() + if not plugin_oauth_config: + raise NotFound() + credentials = oauth_handler.get_credentials( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=plugin_oauth_config.system_credentials, + request=request + ) + datasource_provider = DatasourceProvider( + datasource_name=plugin_oauth_config.datasource_name, + plugin_id=plugin_id, + provider=provider, + auth_type="oauth", + encrypted_credentials=credentials + ) + db.session.add(datasource_provider) + db.session.commit() + return redirect(f"{dify_config.CONSOLE_WEB_URL}") + + api.add_resource(OAuthLogin, "/oauth/login/") api.add_resource(OAuthCallback, "/oauth/authorize/") +api.add_resource(PluginOauthApi, "/oauth/plugin/provider//plugin/") +api.add_resource(PluginOauthCallback, "/oauth/plugin/callback//plugin/") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py b/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py new file mode 100644 index 0000000000..b105606c04 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py @@ -0,0 +1,56 @@ +from typing import cast + +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) +from core.plugin.impl.datasource import PluginDatasourceManager +from extensions.ext_database import db +from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields +from libs.login import login_required +from models import Account +from models.dataset import Pipeline +from services.app_dsl_service import ImportStatus +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + + +class DatasourcePluginOauthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, datasource_type, datasource_name): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + # get all builtin providers + manager = PluginDatasourceManager() + providers = manager.get_provider_oauth_url() + return providers + + + + +# Import Rag Pipeline +api.add_resource( + DatasourcePluginOauthApi, + "/datasource///oauth", +) +api.add_resource( + RagPipelineImportConfirmApi, + "/rag/pipelines/imports//confirm", +) +api.add_resource( + RagPipelineImportCheckDependenciesApi, + "/rag/pipelines/imports//check-dependencies", +) +api.add_resource( + RagPipelineExportApi, + "/rag/pipelines//exports", +) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 63fe4b7f87..bbeaa33341 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -279,7 +279,7 @@ class PublishedRagPipelineRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json") - parser.add_argument("datasource_info", type=list, required=True, location="json") + parser.add_argument("datasource_info_list", type=list, required=True, location="json") parser.add_argument("start_node_id", type=str, required=True, location="json") parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) args = parser.parse_args() diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index dd9eade0a5..23dbfef70d 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -104,7 +104,7 @@ class PipelineRunner(WorkflowBasedAppRunner): SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type, SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, - SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from, + SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value, } variable_pool = VariablePool( @@ -199,4 +199,4 @@ class PipelineRunner(WorkflowBasedAppRunner): if not graph: raise ValueError("graph not found in workflow") - return graph \ No newline at end of file + return graph diff --git a/api/models/oauth.py b/api/models/oauth.py new file mode 100644 index 0000000000..f24c3c6723 --- /dev/null +++ b/api/models/oauth.py @@ -0,0 +1,47 @@ + +from datetime import datetime +from json import JSONDecodeError +from typing import Any, cast + +from sqlalchemy import func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped + +from configs import dify_config +from extensions.ext_storage import storage +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule + +from .account import Account +from .base import Base +from .engine import db +from .model import App, Tag, TagBinding, UploadFile +from .types import StringUUID + + +class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] + __tablename__ = "datasource_oauth_params" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False) + plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) + provider: Mapped[str] = db.Column(db.String(255), nullable=False) + system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + +class DatasourceProvider(Base): + __tablename__ = "datasource_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"), + ) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False) + plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) + provider: Mapped[str] = db.Column(db.String(255), nullable=False) + auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) + encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)