From d3822fe9336f1d19c74e56c340ea78edbd02f04f Mon Sep 17 00:00:00 2001 From: Hri7566 Date: Sun, 10 Sep 2023 00:15:34 -0400 Subject: [PATCH] Fix session problem --- src/channel/Channel.ts | 34 ++++++----- src/util/id.ts | 4 ++ src/ws/Socket.ts | 28 +++++++++- src/ws/server.ts | 124 +++++++++-------------------------------- 4 files changed, 72 insertions(+), 118 deletions(-) diff --git a/src/channel/Channel.ts b/src/channel/Channel.ts index a36f398..da5346c 100644 --- a/src/channel/Channel.ts +++ b/src/channel/Channel.ts @@ -111,21 +111,12 @@ export class Channel { let hasChangedChannel = false; - this.logger.debug("Has user?", this.hasUser(part)); + this.logger.debug("Has user?", this.hasUser(part._id)); // Is user in this channel? - if (this.hasUser(part)) { - // Alreay in channel, disconnect old - - const oldSocket = findSocketByPartID(part.id); - - if (oldSocket) { - oldSocket.destroy(); - } - - // Add to channel - this.ppl.push(part); - hasChangedChannel = true; + if (this.hasUser(part._id)) { + // Alreay in channel, add this part ID to IDs list + // TODO } else { // Are we full? if (!this.isFull()) { @@ -142,10 +133,12 @@ export class Channel { // Is user in any channel that isn't this one? for (const ch of channelList) { if (ch == this) continue; - if (ch.hasUser(part)) { + if (ch.hasUser(part._id)) { ch.leave(socket); } } + + socket.currentChannelID = this.getID(); } this.logger.debug("Participant list:", this.ppl); @@ -167,10 +160,10 @@ export class Channel { this.logger.debug("Leave called"); const part = socket.getParticipant(); - // Same as above... + // Unknown side-effects, but for type safety... if (!part) return; - if (this.hasUser(part)) { + if (this.hasUser(part._id)) { this.ppl.splice(this.ppl.indexOf(part), 1); } // TODO Broadcast channel update @@ -199,8 +192,13 @@ export class Channel { return this.ppl; } - public hasUser(part: Participant) { - const foundPart = this.ppl.find(p => p._id == part._id); + public hasUser(_id: string) { + const foundPart = this.ppl.find(p => p._id == _id); + return !!foundPart; + } + + public hasParticipant(id: string) { + const foundPart = this.ppl.find(p => p.id == id); return !!foundPart; } } diff --git a/src/util/id.ts b/src/util/id.ts index af869b3..1bbbb12 100644 --- a/src/util/id.ts +++ b/src/util/id.ts @@ -13,6 +13,10 @@ export function createUserID(ip: string) { .substring(0, 24); } +export function createSocketID() { + return crypto.randomUUID(); +} + export function createColor(ip: string) { return ( "#" + diff --git a/src/ws/Socket.ts b/src/ws/Socket.ts index 2185d4c..63886ef 100644 --- a/src/ws/Socket.ts +++ b/src/ws/Socket.ts @@ -9,6 +9,8 @@ import { loadConfig } from "../util/config"; import { Gateway } from "./Gateway"; import { Channel, channelList } from "../channel/Channel"; import { ServerWebSocket } from "bun"; +import { findSocketByUserID, socketsBySocketID } from "./server"; +import { Logger } from "../util/Logger"; interface UsersConfig { defaultName: string; @@ -22,6 +24,8 @@ const usersConfig = loadConfig("config/users.yml", { } }); +const logger = new Logger("Sockets"); + export class Socket extends EventEmitter { private id: string; private _id: string; @@ -43,11 +47,29 @@ export class Socket extends EventEmitter { constructor(private ws: ServerWebSocket) { super(); this.ip = ws.remoteAddress; // Participant ID - this.id = createID(); // User ID this._id = createUserID(this.getIP()); - // *cough* lapis + + // Check if we're already connected + // We need to skip ourselves, so we loop here instead of using a helper + let foundSocket; + + for (const socket of socketsBySocketID.values()) { + if (socket == this) continue; + + if (socket.getUserID() == this.getUserID()) { + foundSocket = socket; + } + } + + if (!foundSocket) { + // Use new session ID + this.id = createID(); + } else { + // Use original session ID + this.id = foundSocket.id; + } this.loadUser(); @@ -94,6 +116,7 @@ export class Socket extends EventEmitter { ); channel.join(this); + // TODO Give the crown upon joining } } @@ -189,6 +212,7 @@ export class Socket extends EventEmitter { const foundCh = channelList.find( ch => ch.getID() == this.currentChannelID ); + if (foundCh) foundCh.leave(this); } diff --git a/src/ws/server.ts b/src/ws/server.ts index 4f495b9..2ba00f6 100644 --- a/src/ws/server.ts +++ b/src/ws/server.ts @@ -1,14 +1,6 @@ -// import { -// App, -// DEDICATED_COMPRESSOR_8KB, -// HttpRequest, -// HttpResponse, -// WebSocket -// } from "uWebSockets.js"; import { Logger } from "../util/Logger"; -import { createUserID } from "../util/id"; +import { createSocketID, createUserID } from "../util/id"; import fs from "fs"; -// import { join } from "path"; import path from "path"; import { handleMessage } from "./message"; import { decoder } from "../util/helpers"; @@ -18,99 +10,28 @@ import env from "../util/env"; const logger = new Logger("WebSocket Server"); -const usersByPartID = new Map(); +export const socketsBySocketID = new Map(); export function findSocketByPartID(id: string) { - for (const key of usersByPartID.keys()) { - if (key == id) return usersByPartID.get(key); + for (const socket of socketsBySocketID.values()) { + if (socket.getParticipantID() == id) return socket; } } -// Original uWebSockets code -// export const app = App() -// .get("/*", async (res, req) => { -// const url = req.getUrl(); -// const ip = decoder.decode(res.getRemoteAddressAsText()); -// // logger.debug(`${req.getMethod()} ${url} ${ip}`); -// // res.writeStatus(`200 OK`).end("HI!"); -// const file = join("./public/", url); +export function findSocketByUserID(_id: string) { + for (const socket of socketsBySocketID.values()) { + logger.debug("User ID:", socket.getUserID()); + if (socket.getUserID() == _id) return socket; + } +} -// // TODO Cleaner file serving -// try { -// const stats = lstatSync(file); - -// let data; -// if (!stats.isDirectory()) { -// data = readFileSync(file); -// } - -// // logger.debug(filename); - -// if (!data) { -// const index = readFileSync("./public/index.html"); - -// if (!index) { -// return void res -// .writeStatus(`404 Not Found`) -// .end("uh oh :("); -// } else { -// return void res.writeStatus(`200 OK`).end(index); -// } -// } - -// res.writeStatus(`200 OK`).end(data); -// } catch (err) { -// logger.warn("Unable to serve file at", file); -// logger.error(err); -// const index = readFileSync("./public/index.html"); - -// if (!index) { -// return void res.writeStatus(`404 Not Found`).end("uh oh :("); -// } else { -// return void res.writeStatus(`200 OK`).end(index); -// } -// } -// }) -// .ws("/*", { -// idleTimeout: 25, -// maxBackpressure: 1024, -// maxPayloadLength: 8192, -// compression: DEDICATED_COMPRESSOR_8KB, - -// open: ((ws: WebSocket & { socket: Socket }) => { -// ws.socket = new Socket(ws); -// // logger.debug("Connection at " + ws.socket.getIP()); - -// usersByPartID.set(ws.socket.getParticipantID(), ws.socket); -// }) as (ws: WebSocket) => void, - -// message: (( -// ws: WebSocket & { socket: Socket }, -// message, -// isBinary -// ) => { -// const msg = decoder.decode(message); -// handleMessage(ws.socket, msg); -// }) as ( -// ws: WebSocket, -// message: ArrayBuffer, -// isBinary: boolean -// ) => void, - -// close: (( -// ws: WebSocket & { socket: Socket }, -// code: number, -// message: ArrayBuffer -// ) => { -// logger.debug("Close called"); -// ws.socket.destroy(); -// usersByPartID.delete(ws.socket.getParticipantID()); -// }) as ( -// ws: WebSocket, -// code: number, -// message: ArrayBuffer -// ) => void -// }); +export function findSocketByIP(ip: string) { + for (const socket of socketsBySocketID.values()) { + if (socket.getIP() == ip) { + return socket; + } + } +} export const app = Bun.serve({ port: env.PORT, @@ -147,7 +68,7 @@ export const app = Bun.serve({ (ws as unknown as any).socket = socket; logger.debug("Connection at " + socket.getIP()); - usersByPartID.set(socket.getParticipantID(), socket); + socketsBySocketID.set(createSocketID(), socket); }, message: (ws, message) => { @@ -159,7 +80,14 @@ export const app = Bun.serve({ logger.debug("Close called"); const socket = (ws as unknown as any).socket as Socket; socket.destroy(); - usersByPartID.delete(socket.getParticipantID()); + + for (const sockID of socketsBySocketID.keys()) { + const sock = socketsBySocketID.get(sockID); + + if (sock == socket) { + socketsBySocketID.delete(sockID); + } + } } } });