diff --git a/src/common/channel.js b/src/common/channel.js index d8175ce..376bdf3 100644 --- a/src/common/channel.js +++ b/src/common/channel.js @@ -17,7 +17,7 @@ export const validateSubreddit = (name) => { return !!name.match(/^[A-Za-z0-9][A-Za-z0-9_]{2,20}$/i); }; -export const validateUsername = (name) => name.match(/^[\w-]{3,20}/); +export const validateUsername = (name) => !!name.match(/^[\w-]{3,20}$/); export const validateChannel = (channel) => { if (!channel) { diff --git a/src/server/packet/add-favorite.js b/src/server/packet/add-favorite.js index 879b30c..f8f8337 100644 --- a/src/server/packet/add-favorite.js +++ b/src/server/packet/add-favorite.js @@ -1,10 +1,15 @@ +import {validateChannel} from '~/common/channel'; import AddFavorite from '~/common/packets/add-favorite.packet'; -import {allModels} from '~/server/models/registrar'; +import ValidationError from './validation-error'; export default { Packet: AddFavorite, - validator: async () => true, + validator: async ({data: channel}) => { + if (!validateChannel(channel)) { + throw new ValidationError({code: 400, reason: 'Malformed channel'}); + } + }, responder: async (packet, socket) => { const {req} = socket; await req.user.createFavorite({channel: packet.data}); diff --git a/src/server/packet/add-friend.js b/src/server/packet/add-friend.js index 5da784a..a8b5379 100644 --- a/src/server/packet/add-friend.js +++ b/src/server/packet/add-friend.js @@ -10,10 +10,10 @@ export default { limiter: {points: 20, duration: 60}, validator: async ({data: {nameOrStatus}}) => { if (!validateUsername(nameOrStatus)) { - throw new ValidationError('Invalid username'); + throw new ValidationError({code: 400, reason: 'Malformed username'}); } }, - responder: async (packet, socket) => { + responder: async ({data: {nameOrStatus}}, socket) => { const {req} = socket; const { Friendship, @@ -21,8 +21,8 @@ export default { } = allModels(); const adderId = req.user.id; const user = ( - await User.findOne({where: {redditUsername: packet.data.nameOrStatus}}) - || await User.create({redditUsername: packet.data.nameOrStatus}) + await User.findOne({where: {redditUsername: nameOrStatus}}) + || await User.create({redditUsername: nameOrStatus}) ); const addeeId = user.id; let friendship = await Friendship.findOne({where: {adderId: addeeId, addeeId: adderId}}); diff --git a/src/server/packet/block.js b/src/server/packet/block.js index 9927d79..6432fd2 100644 --- a/src/server/packet/block.js +++ b/src/server/packet/block.js @@ -7,9 +7,16 @@ import {allModels} from '~/server/models/registrar'; import {removeFavoritedUser} from './remove-favorite'; +import ValidationError from './validation-error'; + export default { Packet: Block, - validator: async () => true, + validator: async ({data: id}) => { + const {User} = allModels(); + if (!await User.count({where: {id}})) { + throw new ValidationError({code: 400, reason: 'No such user'}); + } + }, responder: async (packet, socket) => { const {req} = socket; const id = packet.data; diff --git a/src/server/packet/join.js b/src/server/packet/join.js index 2459278..d244fcf 100644 --- a/src/server/packet/join.js +++ b/src/server/packet/join.js @@ -3,6 +3,7 @@ import {promisify} from 'util'; import {ServerSocket} from '@avocado/net/server/socket'; +import {validateChannel} from '~/common/channel'; import Join from '~/common/packets/join.packet'; import { @@ -10,6 +11,8 @@ import { channelUsers, } from '~/server/entry'; +import ValidationError from './validation-error'; + export const userJoin = async (channel, socket) => { const userId = '/r/anonymous' === channel ? 0 : socket.handshake.userId; const users = await channelUsers(socket.handshake, channel); @@ -21,10 +24,13 @@ export const userJoin = async (channel, socket) => { export default { Packet: Join, - validator: async () => true, - responder: async (packet, socket) => { + validator: async () => { + if (!validateChannel()) { + throw new ValidationError({code: 400, reason: 'Malformed channel'}); + } + }, + responder: async ({data: channel}, socket) => { const {req} = socket; - const {channel} = packet.data; await userJoin(channel, socket.socket); return channelState(req, channel); }, diff --git a/src/server/packet/leave.js b/src/server/packet/leave.js index 707298c..acc458b 100644 --- a/src/server/packet/leave.js +++ b/src/server/packet/leave.js @@ -1,10 +1,13 @@ // eslint-disable-next-line import/no-extraneous-dependencies import {promisify} from 'util'; +import {validateChannel} from '~/common/channel'; import Leave from '~/common/packets/leave.packet'; import {channelUserCounts} from '~/server/entry'; +import ValidationError from './validation-error'; + export const userLeave = async (channel, socket) => { const userId = '/r/anonymous' === channel ? 0 : socket.req.userId; await promisify(socket.leave.bind(socket))(channel); @@ -16,9 +19,10 @@ export const userLeave = async (channel, socket) => { export default { Packet: Leave, - validator: async () => true, - responder: async (packet, socket) => { - const {channel} = packet.data; - return userLeave(channel, socket); + validator: async () => { + if (!validateChannel()) { + throw new ValidationError({code: 400, reason: 'Malformed channel'}); + } }, + responder: async ({data: channel}, socket) => userLeave(channel, socket), }; diff --git a/src/server/packet/message.js b/src/server/packet/message.js index 34befe1..d28407f 100644 --- a/src/server/packet/message.js +++ b/src/server/packet/message.js @@ -1,19 +1,28 @@ import {v4 as uuidv4} from 'uuid'; -import {parseChannel} from '~/common/channel'; +import {parseChannel, validateChannel} from '~/common/channel'; import Message from '~/common/packets/message.packet'; import {allModels} from '~/server/models/registrar'; +import ValidationError from './validation-error'; + export default { Packet: Message, limiter: {points: 10, duration: 15}, - validator: async () => true, - responder: async (packet, socket, fn) => { + validator: async ({data: {channel, message}}) => { + if (!validateChannel(channel)) { + throw new ValidationError({code: 400, reason: 'Malformed channel'}); + } + if (message.length > 1024) { + throw new ValidationError({code: 400, reason: 'Your message was a bit too long'}); + } + }, + responder: async ({data}, socket) => { const {req} = socket; const {pubClient} = req.adapter; const {User} = allModels(); - const {channel, message} = packet.data; + const {channel, message} = data; const {name, type} = parseChannel(`/chat${channel}`); const other = await User.findOne({where: {redditUsername: name}}); const owner = '/r/anonymous' === channel ? 0 : req.userId; @@ -26,7 +35,7 @@ export default { const key = `${serverChannel}:messages:${uuid}`; ('u' === type ? [`/user/${other.id}`, `/user/${req.userId}`] : [channel]).forEach((room) => ( socket.to(room, new Message({ - ...packet.data, + ...data, channel: 'r' === type ? channel : `/u/${username === room.substr(3) ? name : username}`, diff --git a/src/server/packet/remove-favorite.js b/src/server/packet/remove-favorite.js index c31672e..9bf6f44 100644 --- a/src/server/packet/remove-favorite.js +++ b/src/server/packet/remove-favorite.js @@ -1,34 +1,37 @@ +import {validateChannel} from '~/common/channel'; import RemoveFavorite from '~/common/packets/remove-favorite.packet'; import {allModels} from '~/server/models/registrar'; +import ValidationError from './validation-error'; + export const removeFavoritedUser = async (socket, user, other) => { const {Favorite} = allModels(); - const favorites = await user.getFavorites(); - const toRemove = favorites.find(({channel}) => channel === `/u/${other.redditUsername}`); - if (toRemove) { - await Favorite.destroy({ - where: { - id: toRemove.id, - }, - }); + const favorite = await Favorite.findOne( + {where: {channel: `/u/${other.redditUsername}`, user_id: user.id}}, + ); + if (favorite) { + await Favorite.destroy({where: {id: favorite.id}}); socket.to(`/user/${user.id}`, new RemoveFavorite(`/u/${other.redditUsername}`)); } }; export default { Packet: RemoveFavorite, - validator: async () => true, + validator: async ({data: channel}, {req: {user}}) => { + const {Favorite} = allModels(); + if (!validateChannel()) { + throw new ValidationError({code: 400, reason: 'Malformed channel.'}); + } + if (0 === await Favorite.count({where: {user_id: user.id, channel}})) { + throw new ValidationError({code: 400, reason: 'No such favorite existed.'}); + } + }, responder: async (packet, socket) => { const {req} = socket; const {Favorite} = allModels(); - const favorites = await req.user.getFavorites(); - const toRemove = favorites.find(({channel}) => channel === packet.data); - await Favorite.destroy({ - where: { - id: toRemove.id, - }, - }); + const favorite = await Favorite.findOne({where: {channel: packet.data, user_id: req.user.id}}); + await Favorite.destroy({where: {id: favorite.id}}); socket.to(`/user/${req.userId}`, packet); }, }; diff --git a/src/server/packet/remove-friend.js b/src/server/packet/remove-friend.js index 0818595..7b580f2 100644 --- a/src/server/packet/remove-friend.js +++ b/src/server/packet/remove-friend.js @@ -6,12 +6,26 @@ import {allModels} from '~/server/models/registrar'; import {removeFavoritedUser} from './remove-favorite'; +import ValidationError from './validation-error'; + export default { Packet: RemoveFriend, - validator: async () => true, - responder: async (packet, socket) => { + validator: async ({data: id}, {req: {userId}}) => { + const {Friendship} = allModels(); + const hasFriendship = !!await Friendship.count({ + where: { + [Op.or]: [ + {[Op.and]: [{addeeId: userId}, {adderId: id}]}, + {[Op.and]: [{addeeId: id}, {adderId: userId}]}, + ], + }, + }); + if (!hasFriendship) { + throw new ValidationError({code: 400, reason: 'Malformed friendship.'}); + } + }, + responder: async ({data: id}, socket) => { const {req} = socket; - const id = packet.data; const {Friendship, User} = allModels(); await Friendship.destroy({ where: { @@ -24,7 +38,9 @@ export default { socket.to(`/user/${id}`, new RemoveFriend(req.userId)); socket.to(`/user/${req.userId}`, new RemoveFriend(id)); const user = await User.findByPk(id); - removeFavoritedUser(socket, user, req.user); - removeFavoritedUser(socket, req.user, user); + return Promise.all([ + removeFavoritedUser(socket, user, req.user), + removeFavoritedUser(socket, req.user, user), + ]); }, }; diff --git a/src/server/packet/unblock.js b/src/server/packet/unblock.js index 8cb85db..42a5596 100644 --- a/src/server/packet/unblock.js +++ b/src/server/packet/unblock.js @@ -2,15 +2,28 @@ import Unblock from '~/common/packets/unblock.packet'; import {allModels} from '~/server/models/registrar'; +import ValidationError from './validation-error'; + export default { Packet: Unblock, - validator: async () => true, - responder: async (packet, socket) => { + validator: async ({data: blocked}, {req: {userId}}) => { + const {Block: BlockModel} = allModels(); + const hasBlock = !!await BlockModel.count({ + where: { + blocked, + user_id: userId, + }, + }); + if (!hasBlock) { + throw new ValidationError({code: 400, reason: "Wasn't blocking."}); + } + }, + responder: async ({data: blocked}, socket) => { const {req} = socket; const {Block: BlockModel} = allModels(); await BlockModel.destroy({ where: { - blocked: packet.data, + blocked, user_id: req.userId, }, }); diff --git a/src/server/packet/validation-error.js b/src/server/packet/validation-error.js index 1e5d16e..3c2798f 100644 --- a/src/server/packet/validation-error.js +++ b/src/server/packet/validation-error.js @@ -1 +1,9 @@ -export default class ValidationError extends Error {} +export default class ValidationError extends Error { + + constructor(...args) { + const [payload, ...after] = args; + super(...after); + this.payload = payload; + } + +} diff --git a/src/server/sockets.js b/src/server/sockets.js index e91c45f..fc04517 100644 --- a/src/server/sockets.js +++ b/src/server/sockets.js @@ -10,6 +10,7 @@ import createLimiter from './limiter'; import * as PacketHandlers from './packet'; import {userJoin} from './packet/join'; import {userLeave} from './packet/leave'; +import ValidationError from './packet/validation-error'; import passport from './passport'; import createRedisClient from './redis'; import session from './session'; @@ -72,6 +73,10 @@ export function createSocketServer(httpServer) { fn(undefined, await responder(packet, socket)); } catch (error) { + if (error instanceof ValidationError) { + fn(error.payload); + return; + } if (error instanceof Error) { fn({code: 500}); throw error;