diff --git a/public/src/sockets.js b/public/src/sockets.js index 6d3ccfbb80..acd3ba3dbe 100644 --- a/public/src/sockets.js +++ b/public/src/sockets.js @@ -15,6 +15,9 @@ app = window.app || {}; reconnectionDelay: config.reconnectionDelay, transports: config.socketioTransports, path: config.relative_path + '/socket.io', + query: { + _csrf: config.csrf_token, + }, }; window.socket = io(config.websocketAddress, ioParams); diff --git a/src/middleware/csrf.js b/src/middleware/csrf.js index f6af0c625b..be5b0761fe 100644 --- a/src/middleware/csrf.js +++ b/src/middleware/csrf.js @@ -5,12 +5,15 @@ const { csrfSync } = require('csrf-sync'); const { generateToken, csrfSynchronisedProtection, + isRequestValid, } = csrfSync({ getTokenFromRequest: (req) => { if (req.headers['x-csrf-token']) { return req.headers['x-csrf-token']; - } else if (req.body.csrf_token) { + } else if (req.body && req.body.csrf_token) { return req.body.csrf_token; + } else if (req.query) { + return req.query._csrf; } }, size: 64, @@ -19,4 +22,5 @@ const { module.exports = { generateToken, csrfSynchronisedProtection, + isRequestValid, }; diff --git a/src/socket.io/index.js b/src/socket.io/index.js index 963267ed9a..d457afefbc 100644 --- a/src/socket.io/index.js +++ b/src/socket.io/index.js @@ -34,13 +34,25 @@ Sockets.init = async function (server) { } } - io.use(authorize); - io.on('connection', onConnection); const opts = { transports: nconf.get('socket.io:transports') || ['polling', 'websocket'], cookie: false, + allowRequest: (req, callback) => { + authorize(req, (err) => { + if (err) { + return callback(err); + } + const csrf = require('../middleware/csrf'); + const isValid = csrf.isRequestValid({ + session: req.session || {}, + query: req._query, + headers: req.headers, + }); + callback(null, isValid); + }); + }, }; /* * Restrict socket.io listener to cookie domain. If none is set, infer based on url. @@ -62,7 +74,11 @@ Sockets.init = async function (server) { }; function onConnection(socket) { - socket.ip = (socket.request.headers['x-forwarded-for'] || socket.request.connection.remoteAddress || '').split(',')[0]; + socket.uid = socket.request.uid; + socket.ip = ( + socket.request.headers['x-forwarded-for'] || + socket.request.connection.remoteAddress || '' + ).split(',')[0]; socket.request.ip = socket.ip; logger.io_one(socket, socket.uid); @@ -231,9 +247,7 @@ async function validateSession(socket, errorMsg) { const cookieParserAsync = util.promisify((req, callback) => cookieParser(req, {}, err => callback(err))); -async function authorize(socket, callback) { - const { request } = socket; - +async function authorize(request, callback) { if (!request) { return callback(new Error('[[error:not-authorized]]')); } @@ -246,15 +260,13 @@ async function authorize(socket, callback) { }); const sessionData = await getSessionAsync(sessionId); - + request.session = sessionData; + let uid = 0; if (sessionData && sessionData.passport && sessionData.passport.user) { - request.session = sessionData; - socket.uid = parseInt(sessionData.passport.user, 10); - } else { - socket.uid = 0; + uid = parseInt(sessionData.passport.user, 10); } - request.uid = socket.uid; - callback(); + request.uid = uid; + callback(null, uid); } Sockets.in = function (room) { diff --git a/test/helpers/index.js b/test/helpers/index.js index b79bf66e2d..4adcb1b0dd 100644 --- a/test/helpers/index.js +++ b/test/helpers/index.js @@ -96,7 +96,7 @@ helpers.logoutUser = function (jar, callback) { }); }; -helpers.connectSocketIO = function (res, callback) { +helpers.connectSocketIO = function (res, csrf_token, callback) { const io = require('socket.io-client'); let cookies = res.headers['set-cookie']; cookies = cookies.filter(c => /express.sid=[^;]+;/.test(c)); @@ -107,6 +107,9 @@ helpers.connectSocketIO = function (res, callback) { Origin: nconf.get('url'), Cookie: cookie, }, + query: { + _csrf: csrf_token, + }, }); socket.on('connect', () => { diff --git a/test/socket.io.js b/test/socket.io.js index 726f6f71ad..4e2292d4e5 100644 --- a/test/socket.io.js +++ b/test/socket.io.js @@ -73,7 +73,7 @@ describe('socket.io', () => { }, (err, res) => { assert.ifError(err); - helpers.connectSocketIO(res, (err, _io) => { + helpers.connectSocketIO(res, body.csrf_token, (err, _io) => { io = _io; assert.ifError(err);