From b0868d913694d362eccc10e8afcec4859db67d74 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Thu, 17 Jul 2025 17:16:38 +0800 Subject: [PATCH] fix websocket auth --- api/controllers/console/__init__.py | 1 + api/controllers/console/app/online_user.py | 6 ++-- api/extensions/ext_login.py | 34 +++++++++++++++++++++- web/service/demo/online-user.ts | 4 ++- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index e25f92399c..5055fbcc01 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -59,6 +59,7 @@ from .app import ( mcp_server, message, model_config, + online_user, ops_trace, site, statistic, diff --git a/api/controllers/console/app/online_user.py b/api/controllers/console/app/online_user.py index 6f05aae5c5..fc03649d06 100644 --- a/api/controllers/console/app/online_user.py +++ b/api/controllers/console/app/online_user.py @@ -1,21 +1,21 @@ import json from flask import request -from flask_login import current_user, login_required from extensions.ext_redis import redis_client from extensions.ext_socketio import ext_socketio + @ext_socketio.on("user_connect") -@login_required def handle_user_connect(data): """ Handle user connect event, check login and get user info. """ - sid = request.sid workflow_id = data.get("workflow_id") + if not (current_user := request.environ.get("ws_user")): + return {"msg": "unauthorized"}, 401 old_info_json = redis_client.hget(f"workflow_online_users:{workflow_id}", current_user.id) if old_info_json: diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 11d1856ac4..3ff8ec0390 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,13 +1,15 @@ import json import flask_login # type: ignore -from flask import Response, request +from flask import Response, request, g +from flask_socketio import disconnect from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config from dify_app import DifyApp from extensions.ext_database import db +from extensions.ext_socketio import ext_socketio from libs.passport import PassportService from models.account import Account, Tenant, TenantAccountJoin from models.model import AppMCPServer, EndUser @@ -113,5 +115,35 @@ def unauthorized_handler(): ) +@ext_socketio.on('connect') +def socket_connect(auth): + """ + WebSocket connect event, do authentication here. + """ + token = None + if auth and isinstance(auth, dict): + token = auth.get('token') + if not token: + disconnect() + return False + + try: + decoded = PassportService().verify(token) + user_id = decoded.get("user_id") + if not user_id: + disconnect() + return False + user = AccountService.load_logged_in_account(account_id=user_id) + if not user: + disconnect() + return False + + request.environ['ws_user'] = user + + except Exception: + disconnect() + return False + + def init_app(app: DifyApp): login_manager.init_app(app) diff --git a/web/service/demo/online-user.ts b/web/service/demo/online-user.ts index 34e608fe85..f2e2db87b7 100644 --- a/web/service/demo/online-user.ts +++ b/web/service/demo/online-user.ts @@ -15,16 +15,18 @@ export function connectOnlineUserWebSocket(appId: string): Socket { socket.disconnect() const url = process.env.NEXT_PUBLIC_SOCKET_URL || 'ws://localhost:5001' + const token = localStorage.getItem('console_token') socket = io(url, { path: '/socket.io', transports: ['websocket'], - query: { app_id: appId }, + auth: { token }, withCredentials: true, }) // Add your event listeners here socket.on('connect', () => { + socket?.emit('user_connect', { workflow_id: appId }) console.log('WebSocket connected') })