mirror of https://github.com/langgenius/dify.git
migrate to python-socketio
This commit is contained in:
parent
f9aedb2118
commit
4081937e22
13
api/app.py
13
api/app.py
|
|
@ -33,10 +33,17 @@ else:
|
|||
psycogreen.gevent.patch_psycopg()
|
||||
|
||||
from app_factory import create_app
|
||||
from extensions.ext_socketio import ext_socketio
|
||||
|
||||
app = create_app()
|
||||
socketio_app, app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
ext_socketio.run(app, host="0.0.0.0", port=5001, debug=True)
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler
|
||||
|
||||
host = os.environ.get('HOST', '0.0.0.0')
|
||||
port = int(os.environ.get('PORT', 5001))
|
||||
|
||||
print(f"Starting server on {host}:{port}")
|
||||
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
|
|
|
|||
|
|
@ -26,14 +26,21 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> DifyApp:
|
||||
def create_app() -> tuple[any, DifyApp]:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
import socketio
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)")
|
||||
return app
|
||||
return socketio_app, app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
|
|
@ -57,7 +64,6 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_request_logging,
|
||||
ext_sentry,
|
||||
ext_set_secretkey,
|
||||
ext_socketio,
|
||||
ext_storage,
|
||||
ext_timezone,
|
||||
ext_warnings,
|
||||
|
|
@ -86,7 +92,6 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_commands,
|
||||
ext_otel,
|
||||
ext_request_logging,
|
||||
ext_socketio,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
|
|
|||
|
|
@ -1,57 +1,60 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_socketio import join_room
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_socketio import ext_socketio
|
||||
from extensions.ext_socketio import sio
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from libs.login import login_required
|
||||
|
||||
|
||||
@ext_socketio.on("user_connect")
|
||||
def handle_user_connect(data):
|
||||
@sio.on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event, check login and get user info.
|
||||
"""
|
||||
sid = request.sid
|
||||
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not (current_user := request.environ.get("ws_user")):
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
session = sio.get_session(sid)
|
||||
user_id = session.get('user_id')
|
||||
|
||||
if not user_id:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
old_info_json = redis_client.hget(f"workflow_online_users:{workflow_id}", current_user.id)
|
||||
old_info_json = redis_client.hget(f"workflow_online_users:{workflow_id}", user_id)
|
||||
if old_info_json:
|
||||
old_info = json.loads(old_info_json)
|
||||
old_sid = old_info.get("sid")
|
||||
if old_sid and old_sid != sid:
|
||||
ext_socketio.server.disconnect(sid=old_sid)
|
||||
sio.disconnect(sid=old_sid)
|
||||
|
||||
user_info = {
|
||||
"user_id": current_user.id,
|
||||
"username": current_user.name,
|
||||
"avatar": current_user.avatar,
|
||||
"user_id": user_id,
|
||||
"username": session.get('username', 'Unknown'),
|
||||
"avatar": session.get('avatar', None),
|
||||
"sid": sid,
|
||||
}
|
||||
|
||||
redis_client.hset(f"workflow_online_users:{workflow_id}", current_user.id, json.dumps(user_info))
|
||||
redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": current_user.id}))
|
||||
redis_client.hset(f"workflow_online_users:{workflow_id}", user_id, json.dumps(user_info))
|
||||
redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": user_id}))
|
||||
|
||||
join_room(workflow_id)
|
||||
sio.enter_room(sid, workflow_id)
|
||||
broadcast_online_users(workflow_id)
|
||||
|
||||
return {"msg": "connected", "user_id": current_user.id, "sid": sid}
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid}
|
||||
|
||||
|
||||
@ext_socketio.on("disconnect")
|
||||
def handle_disconnect():
|
||||
@sio.on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle user disconnect event, remove user from workflow's online user list.
|
||||
"""
|
||||
sid = request.sid
|
||||
mapping = redis_client.get(f"ws_sid_map:{sid}")
|
||||
if mapping:
|
||||
data = json.loads(mapping)
|
||||
|
|
@ -62,6 +65,7 @@ def handle_disconnect():
|
|||
|
||||
broadcast_online_users(workflow_id)
|
||||
|
||||
|
||||
def broadcast_online_users(workflow_id):
|
||||
"""
|
||||
broadcast online users to the workflow room
|
||||
|
|
@ -73,7 +77,7 @@ def broadcast_online_users(workflow_id):
|
|||
users.append(json.loads(user_info_json))
|
||||
except Exception:
|
||||
continue
|
||||
ext_socketio.emit(
|
||||
sio.emit(
|
||||
"online_users",
|
||||
{"workflow_id": workflow_id, "users": users},
|
||||
room=workflow_id
|
||||
|
|
@ -91,11 +95,11 @@ class OnlineUserApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")]
|
||||
|
||||
|
||||
results = {}
|
||||
for workflow_id in workflow_ids:
|
||||
users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
|
||||
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
|
|
@ -103,38 +107,38 @@ class OnlineUserApi(Resource):
|
|||
except Exception:
|
||||
continue
|
||||
results[workflow_id] = users
|
||||
|
||||
|
||||
return {"data": results}
|
||||
|
||||
|
||||
api.add_resource(OnlineUserApi, "/online-users")
|
||||
|
||||
|
||||
@ext_socketio.on("collaboration_event")
|
||||
def handle_collaboration_event(data):
|
||||
@sio.on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouseMove
|
||||
2. openPanel
|
||||
|
||||
|
||||
"""
|
||||
sid = request.sid
|
||||
mapping = redis_client.get(f"ws_sid_map:{sid}")
|
||||
|
||||
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
|
||||
mapping_data = json.loads(mapping)
|
||||
workflow_id = mapping_data["workflow_id"]
|
||||
user_id = mapping_data["user_id"]
|
||||
|
||||
|
||||
event_type = data.get("type")
|
||||
event_data = data.get("data")
|
||||
timestamp = data.get("timestamp", int(time.time()))
|
||||
|
||||
|
||||
if not event_type or not event_data:
|
||||
return {"msg": "invalid event data"}, 400
|
||||
|
||||
ext_socketio.emit(
|
||||
|
||||
sio.emit(
|
||||
"collaboration_update",
|
||||
{
|
||||
"type": event_type,
|
||||
|
|
@ -145,5 +149,5 @@ def handle_collaboration_event(data):
|
|||
room=workflow_id,
|
||||
skip_sid=sid
|
||||
)
|
||||
|
||||
|
||||
return {"msg": "event_broadcasted"}
|
||||
|
|
|
|||
|
|
@ -3,13 +3,12 @@ import json
|
|||
import flask_login # type: ignore
|
||||
from flask import Response, request
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from flask_socketio import disconnect
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_socketio import ext_socketio
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.model import AppMCPServer, EndUser
|
||||
|
|
@ -115,8 +114,8 @@ def unauthorized_handler():
|
|||
)
|
||||
|
||||
|
||||
@ext_socketio.on('connect')
|
||||
def socket_connect(auth):
|
||||
@sio.on('connect')
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
|
|
@ -124,24 +123,28 @@ def socket_connect(auth):
|
|||
if auth and isinstance(auth, dict):
|
||||
token = auth.get('token')
|
||||
if not token:
|
||||
disconnect()
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
disconnect()
|
||||
return False
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
disconnect()
|
||||
return False
|
||||
|
||||
request.environ['ws_user'] = user
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
sio.save_session(sid, {
|
||||
'user_id': user.id,
|
||||
'username': user.name,
|
||||
'avatar': user.avatar
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
disconnect()
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,5 @@
|
|||
from flask_socketio import SocketIO
|
||||
import socketio
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
|
||||
ext_socketio = SocketIO()
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
ext_socketio.init_app(app, async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)
|
||||
sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)
|
||||
|
|
|
|||
Loading…
Reference in New Issue