From 4081937e22e72a6078f5f28a9595d0107bfb8a97 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Mon, 21 Jul 2025 14:57:03 +0800 Subject: [PATCH] migrate to python-socketio --- api/app.py | 13 +++- api/app_factory.py | 13 ++-- api/controllers/console/app/online_user.py | 72 ++++++++++++---------- api/extensions/ext_login.py | 27 ++++---- api/extensions/ext_socketio.py | 9 +-- 5 files changed, 74 insertions(+), 60 deletions(-) diff --git a/api/app.py b/api/app.py index 6271c60c0b..080df69b57 100644 --- a/api/app.py +++ b/api/app.py @@ -33,10 +33,17 @@ else: psycogreen.gevent.patch_psycopg() from app_factory import create_app - from extensions.ext_socketio import ext_socketio - app = create_app() + socketio_app, app = create_app() celery = app.extensions["celery"] if __name__ == "__main__": - ext_socketio.run(app, host="0.0.0.0", port=5001, debug=True) + from gevent import pywsgi + from geventwebsocket.handler import WebSocketHandler + + host = os.environ.get('HOST', '0.0.0.0') + port = int(os.environ.get('PORT', 5001)) + + print(f"Starting server on {host}:{port}") + server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler) + server.serve_forever() diff --git a/api/app_factory.py b/api/app_factory.py index d13593d7c0..6ad46c0a45 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -26,14 +26,21 @@ def create_flask_app_with_configs() -> DifyApp: return dify_app -def create_app() -> DifyApp: +def create_app() -> tuple[any, DifyApp]: start_time = time.perf_counter() app = create_flask_app_with_configs() initialize_extensions(app) + + import socketio + + from extensions.ext_socketio import sio + sio.app = app + socketio_app = socketio.WSGIApp(sio, app) + end_time = time.perf_counter() if dify_config.DEBUG: logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)") - return app + return socketio_app, app def initialize_extensions(app: DifyApp): @@ -57,7 +64,6 @@ def initialize_extensions(app: DifyApp): ext_request_logging, ext_sentry, ext_set_secretkey, - ext_socketio, ext_storage, ext_timezone, ext_warnings, @@ -86,7 +92,6 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_otel, ext_request_logging, - ext_socketio, ] for ext in extensions: short_name = ext.__name__.split(".")[-1] diff --git a/api/controllers/console/app/online_user.py b/api/controllers/console/app/online_user.py index da50d1508d..7d5653dd14 100644 --- a/api/controllers/console/app/online_user.py +++ b/api/controllers/console/app/online_user.py @@ -1,57 +1,60 @@ import json import time -from flask import request from flask_restful import Resource, marshal_with, reqparse -from flask_socketio import join_room from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_redis import redis_client -from extensions.ext_socketio import ext_socketio +from extensions.ext_socketio import sio from fields.online_user_fields import online_user_list_fields from libs.login import login_required -@ext_socketio.on("user_connect") -def handle_user_connect(data): +@sio.on("user_connect") +def handle_user_connect(sid, data): """ Handle user connect event, check login and get user info. """ - sid = request.sid + workflow_id = data.get("workflow_id") - if not (current_user := request.environ.get("ws_user")): + if not workflow_id: + return {"msg": "workflow_id is required"}, 400 + + session = sio.get_session(sid) + user_id = session.get('user_id') + + if not user_id: return {"msg": "unauthorized"}, 401 - old_info_json = redis_client.hget(f"workflow_online_users:{workflow_id}", current_user.id) + old_info_json = redis_client.hget(f"workflow_online_users:{workflow_id}", user_id) if old_info_json: old_info = json.loads(old_info_json) old_sid = old_info.get("sid") if old_sid and old_sid != sid: - ext_socketio.server.disconnect(sid=old_sid) + sio.disconnect(sid=old_sid) user_info = { - "user_id": current_user.id, - "username": current_user.name, - "avatar": current_user.avatar, + "user_id": user_id, + "username": session.get('username', 'Unknown'), + "avatar": session.get('avatar', None), "sid": sid, } - redis_client.hset(f"workflow_online_users:{workflow_id}", current_user.id, json.dumps(user_info)) - redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": current_user.id})) + redis_client.hset(f"workflow_online_users:{workflow_id}", user_id, json.dumps(user_info)) + redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": user_id})) - join_room(workflow_id) + sio.enter_room(sid, workflow_id) broadcast_online_users(workflow_id) - return {"msg": "connected", "user_id": current_user.id, "sid": sid} + return {"msg": "connected", "user_id": user_id, "sid": sid} -@ext_socketio.on("disconnect") -def handle_disconnect(): +@sio.on("disconnect") +def handle_disconnect(sid): """ Handle user disconnect event, remove user from workflow's online user list. """ - sid = request.sid mapping = redis_client.get(f"ws_sid_map:{sid}") if mapping: data = json.loads(mapping) @@ -62,6 +65,7 @@ def handle_disconnect(): broadcast_online_users(workflow_id) + def broadcast_online_users(workflow_id): """ broadcast online users to the workflow room @@ -73,7 +77,7 @@ def broadcast_online_users(workflow_id): users.append(json.loads(user_info_json)) except Exception: continue - ext_socketio.emit( + sio.emit( "online_users", {"workflow_id": workflow_id, "users": users}, room=workflow_id @@ -91,11 +95,11 @@ class OnlineUserApi(Resource): args = parser.parse_args() workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")] - + results = {} for workflow_id in workflow_ids: users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}") - + users = [] for _, user_info_json in users_json.items(): try: @@ -103,38 +107,38 @@ class OnlineUserApi(Resource): except Exception: continue results[workflow_id] = users - + return {"data": results} + api.add_resource(OnlineUserApi, "/online-users") -@ext_socketio.on("collaboration_event") -def handle_collaboration_event(data): +@sio.on("collaboration_event") +def handle_collaboration_event(sid, data): """ Handle general collaboration events, include: 1. mouseMove 2. openPanel - + """ - sid = request.sid mapping = redis_client.get(f"ws_sid_map:{sid}") - + if not mapping: return {"msg": "unauthorized"}, 401 - + mapping_data = json.loads(mapping) workflow_id = mapping_data["workflow_id"] user_id = mapping_data["user_id"] - + event_type = data.get("type") event_data = data.get("data") timestamp = data.get("timestamp", int(time.time())) - + if not event_type or not event_data: return {"msg": "invalid event data"}, 400 - - ext_socketio.emit( + + sio.emit( "collaboration_update", { "type": event_type, @@ -145,5 +149,5 @@ def handle_collaboration_event(data): room=workflow_id, skip_sid=sid ) - + return {"msg": "event_broadcasted"} diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 39be9929c7..72069465fc 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -3,13 +3,12 @@ import json import flask_login # type: ignore from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in -from flask_socketio import disconnect from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config from dify_app import DifyApp from extensions.ext_database import db -from extensions.ext_socketio import ext_socketio +from extensions.ext_socketio import sio from libs.passport import PassportService from models.account import Account, Tenant, TenantAccountJoin from models.model import AppMCPServer, EndUser @@ -115,8 +114,8 @@ def unauthorized_handler(): ) -@ext_socketio.on('connect') -def socket_connect(auth): +@sio.on('connect') +def socket_connect(sid, environ, auth): """ WebSocket connect event, do authentication here. """ @@ -124,24 +123,28 @@ def socket_connect(auth): if auth and isinstance(auth, dict): token = auth.get('token') if not token: - disconnect() return False try: decoded = PassportService().verify(token) user_id = decoded.get("user_id") if not user_id: - disconnect() - return False - user = AccountService.load_logged_in_account(account_id=user_id) - if not user: - disconnect() return False - request.environ['ws_user'] = user + with sio.app.app_context(): + user = AccountService.load_logged_in_account(account_id=user_id) + if not user: + return False + + sio.save_session(sid, { + 'user_id': user.id, + 'username': user.name, + 'avatar': user.avatar + }) + + return True except Exception: - disconnect() return False diff --git a/api/extensions/ext_socketio.py b/api/extensions/ext_socketio.py index 3f5f5db976..470e0b08e2 100644 --- a/api/extensions/ext_socketio.py +++ b/api/extensions/ext_socketio.py @@ -1,10 +1,5 @@ -from flask_socketio import SocketIO +import socketio from configs import dify_config -from dify_app import DifyApp -ext_socketio = SocketIO() - - -def init_app(app: DifyApp): - ext_socketio.init_app(app, async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS) +sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)