migrate to python-socketio

This commit is contained in:
hjlarry 2025-07-21 14:57:03 +08:00
parent f9aedb2118
commit 4081937e22
5 changed files with 74 additions and 60 deletions

View File

@ -33,10 +33,17 @@ else:
psycogreen.gevent.patch_psycopg()
from app_factory import create_app
from extensions.ext_socketio import ext_socketio
app = create_app()
socketio_app, app = create_app()
celery = app.extensions["celery"]
if __name__ == "__main__":
ext_socketio.run(app, host="0.0.0.0", port=5001, debug=True)
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
host = os.environ.get('HOST', '0.0.0.0')
port = int(os.environ.get('PORT', 5001))
print(f"Starting server on {host}:{port}")
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
server.serve_forever()

View File

@ -26,14 +26,21 @@ def create_flask_app_with_configs() -> DifyApp:
return dify_app
def create_app() -> DifyApp:
def create_app() -> tuple[any, DifyApp]:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
initialize_extensions(app)
import socketio
from extensions.ext_socketio import sio
sio.app = app
socketio_app = socketio.WSGIApp(sio, app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)")
return app
return socketio_app, app
def initialize_extensions(app: DifyApp):
@ -57,7 +64,6 @@ def initialize_extensions(app: DifyApp):
ext_request_logging,
ext_sentry,
ext_set_secretkey,
ext_socketio,
ext_storage,
ext_timezone,
ext_warnings,
@ -86,7 +92,6 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_otel,
ext_request_logging,
ext_socketio,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

View File

@ -1,57 +1,60 @@
import json
import time
from flask import request
from flask_restful import Resource, marshal_with, reqparse
from flask_socketio import join_room
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_redis import redis_client
from extensions.ext_socketio import ext_socketio
from extensions.ext_socketio import sio
from fields.online_user_fields import online_user_list_fields
from libs.login import login_required
@ext_socketio.on("user_connect")
def handle_user_connect(data):
@sio.on("user_connect")
def handle_user_connect(sid, data):
"""
Handle user connect event, check login and get user info.
"""
sid = request.sid
workflow_id = data.get("workflow_id")
if not (current_user := request.environ.get("ws_user")):
if not workflow_id:
return {"msg": "workflow_id is required"}, 400
session = sio.get_session(sid)
user_id = session.get('user_id')
if not user_id:
return {"msg": "unauthorized"}, 401
old_info_json = redis_client.hget(f"workflow_online_users:{workflow_id}", current_user.id)
old_info_json = redis_client.hget(f"workflow_online_users:{workflow_id}", user_id)
if old_info_json:
old_info = json.loads(old_info_json)
old_sid = old_info.get("sid")
if old_sid and old_sid != sid:
ext_socketio.server.disconnect(sid=old_sid)
sio.disconnect(sid=old_sid)
user_info = {
"user_id": current_user.id,
"username": current_user.name,
"avatar": current_user.avatar,
"user_id": user_id,
"username": session.get('username', 'Unknown'),
"avatar": session.get('avatar', None),
"sid": sid,
}
redis_client.hset(f"workflow_online_users:{workflow_id}", current_user.id, json.dumps(user_info))
redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": current_user.id}))
redis_client.hset(f"workflow_online_users:{workflow_id}", user_id, json.dumps(user_info))
redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": user_id}))
join_room(workflow_id)
sio.enter_room(sid, workflow_id)
broadcast_online_users(workflow_id)
return {"msg": "connected", "user_id": current_user.id, "sid": sid}
return {"msg": "connected", "user_id": user_id, "sid": sid}
@ext_socketio.on("disconnect")
def handle_disconnect():
@sio.on("disconnect")
def handle_disconnect(sid):
"""
Handle user disconnect event, remove user from workflow's online user list.
"""
sid = request.sid
mapping = redis_client.get(f"ws_sid_map:{sid}")
if mapping:
data = json.loads(mapping)
@ -62,6 +65,7 @@ def handle_disconnect():
broadcast_online_users(workflow_id)
def broadcast_online_users(workflow_id):
"""
broadcast online users to the workflow room
@ -73,7 +77,7 @@ def broadcast_online_users(workflow_id):
users.append(json.loads(user_info_json))
except Exception:
continue
ext_socketio.emit(
sio.emit(
"online_users",
{"workflow_id": workflow_id, "users": users},
room=workflow_id
@ -91,11 +95,11 @@ class OnlineUserApi(Resource):
args = parser.parse_args()
workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")]
results = {}
for workflow_id in workflow_ids:
users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
users = []
for _, user_info_json in users_json.items():
try:
@ -103,38 +107,38 @@ class OnlineUserApi(Resource):
except Exception:
continue
results[workflow_id] = users
return {"data": results}
api.add_resource(OnlineUserApi, "/online-users")
@ext_socketio.on("collaboration_event")
def handle_collaboration_event(data):
@sio.on("collaboration_event")
def handle_collaboration_event(sid, data):
"""
Handle general collaboration events, include:
1. mouseMove
2. openPanel
"""
sid = request.sid
mapping = redis_client.get(f"ws_sid_map:{sid}")
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
user_id = mapping_data["user_id"]
event_type = data.get("type")
event_data = data.get("data")
timestamp = data.get("timestamp", int(time.time()))
if not event_type or not event_data:
return {"msg": "invalid event data"}, 400
ext_socketio.emit(
sio.emit(
"collaboration_update",
{
"type": event_type,
@ -145,5 +149,5 @@ def handle_collaboration_event(data):
room=workflow_id,
skip_sid=sid
)
return {"msg": "event_broadcasted"}

View File

@ -3,13 +3,12 @@ import json
import flask_login # type: ignore
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from flask_socketio import disconnect
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
from extensions.ext_socketio import ext_socketio
from extensions.ext_socketio import sio
from libs.passport import PassportService
from models.account import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
@ -115,8 +114,8 @@ def unauthorized_handler():
)
@ext_socketio.on('connect')
def socket_connect(auth):
@sio.on('connect')
def socket_connect(sid, environ, auth):
"""
WebSocket connect event, do authentication here.
"""
@ -124,24 +123,28 @@ def socket_connect(auth):
if auth and isinstance(auth, dict):
token = auth.get('token')
if not token:
disconnect()
return False
try:
decoded = PassportService().verify(token)
user_id = decoded.get("user_id")
if not user_id:
disconnect()
return False
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
disconnect()
return False
request.environ['ws_user'] = user
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
return False
sio.save_session(sid, {
'user_id': user.id,
'username': user.name,
'avatar': user.avatar
})
return True
except Exception:
disconnect()
return False

View File

@ -1,10 +1,5 @@
from flask_socketio import SocketIO
import socketio
from configs import dify_config
from dify_app import DifyApp
ext_socketio = SocketIO()
def init_app(app: DifyApp):
ext_socketio.init_app(app, async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)
sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)