diff --git a/src/reactive-rpc/server/ws/server/WsServerConnection.ts b/src/reactive-rpc/server/ws/server/WsServerConnection.ts index 11fbf6bd4b..0e842e7c74 100644 --- a/src/reactive-rpc/server/ws/server/WsServerConnection.ts +++ b/src/reactive-rpc/server/ws/server/WsServerConnection.ts @@ -1,73 +1,104 @@ -import * as net from 'net'; import * as crypto from 'crypto'; +import * as stream from 'stream'; import {WsCloseFrame, WsFrameDecoder, WsFrameHeader, WsFrameOpcode, WsPingFrame, WsPongFrame} from '../codec'; import {utf8Size} from '../../../../util/strings/utf8'; -import {FanOut} from 'thingies/es2020/fanout'; +import {listToUint8} from '../../../../util/buffers/concat'; import type {WsFrameEncoder} from '../codec/WsFrameEncoder'; +export type WsServerConnectionSocket = stream.Duplex; + export class WsServerConnection { public closed: boolean = false; public maxIncomingMessage: number = 2 * 1024 * 1024; public maxBackpressure: number = 2 * 1024 * 1024; - /** - * If this is not null, then the connection is receiving a stream: a sequence - * of fragment frames. - */ - protected stream: FanOut | null = null; - public readonly defaultOnPing = (data: Uint8Array | null): void => { this.sendPong(data); }; + private _fragments: Uint8Array[] = []; + private _fragmentsSize: number = 0; + public readonly defaultOnFragment = (isLast: boolean, data: Uint8Array, isUtf8: boolean): void => { + const fragments = this._fragments; + this._fragmentsSize += data.length; + if (this._fragmentsSize > this.maxIncomingMessage) { + this.onClose(1009, 'TOO_LARGE'); + return; + } + fragments.push(data); + if (!isLast) return; + this._fragments = []; + this._fragmentsSize = 0; + const message = listToUint8(fragments); + this.onmessage(message, isUtf8); + }; + public onmessage: (data: Uint8Array, isUtf8: boolean) => void = () => {}; + public onfragment: (isLast: boolean, data: Uint8Array, isUtf8: boolean) => void = this.defaultOnFragment; public onping: (data: Uint8Array | null) => void = this.defaultOnPing; public onpong: (data: Uint8Array | null) => void = () => {}; public onclose: (code: number, reason: string) => void = () => {}; - constructor(protected readonly encoder: WsFrameEncoder, public readonly socket: net.Socket) { + constructor(protected readonly encoder: WsFrameEncoder, public readonly socket: WsServerConnectionSocket) { const decoder = new WsFrameDecoder(); - let currentFrame: WsFrameHeader | null = null; + let currentFrameHeader: WsFrameHeader | null = null; + let fragmentStartFrameHeader: WsFrameHeader | null = null; const handleData = (data: Uint8Array): void => { try { decoder.push(data); - if (currentFrame) { - const length = currentFrame.length; - if (length <= decoder.reader.size()) { - const buf = new Uint8Array(length); - decoder.copyFrameData(currentFrame, buf, 0); - const isText = currentFrame.opcode === WsFrameOpcode.TEXT; - currentFrame = null; - this.onmessage(buf, isText); + main: while (true) { + if (currentFrameHeader instanceof WsFrameHeader) { + const length = currentFrameHeader.length; + if (length > this.maxIncomingMessage) { + this.onClose(1009, 'TOO_LARGE'); + return; + } + if (length <= decoder.reader.size()) { + const buf = new Uint8Array(length); + decoder.copyFrameData(currentFrameHeader, buf, 0); + if (fragmentStartFrameHeader instanceof WsFrameHeader) { + const isText = fragmentStartFrameHeader.opcode === WsFrameOpcode.TEXT; + const isLast = currentFrameHeader.fin === 1; + currentFrameHeader = null; + if (isLast) fragmentStartFrameHeader = null; + this.onfragment(isLast, buf, isText); + } else { + const isText = currentFrameHeader.opcode === WsFrameOpcode.TEXT; + currentFrameHeader = null; + this.onmessage(buf, isText); + } + } else break; } - } - while (true) { const frame = decoder.readFrameHeader(); if (!frame) break; - else if (frame instanceof WsPingFrame) this.onping(frame.data); - else if (frame instanceof WsPongFrame) this.onpong(frame.data); - else if (frame instanceof WsCloseFrame) this.onClose(frame.code, frame.reason); - else if (frame instanceof WsFrameHeader) { - if (this.stream) { + if (frame instanceof WsPingFrame) { + this.onping(frame.data); + continue main; + } + if (frame instanceof WsPongFrame) { + this.onpong(frame.data); + continue main; + } + if (frame instanceof WsCloseFrame) { + decoder.readCloseFrameData(frame); + this.onClose(frame.code, frame.reason); + continue main; + } + if (frame instanceof WsFrameHeader) { + if (fragmentStartFrameHeader) { if (frame.opcode !== WsFrameOpcode.CONTINUE) { this.onClose(1002, 'DATA'); return; } - throw new Error('streaming not implemented'); + currentFrameHeader = frame; } - const length = frame.length; - if (length > this.maxIncomingMessage) { - this.onClose(1009, 'TOO_LARGE'); - return; - } - if (length <= decoder.reader.size()) { - const buf = new Uint8Array(length); - decoder.copyFrameData(frame, buf, 0); - const isText = frame.opcode === WsFrameOpcode.TEXT; - this.onmessage(buf, isText); - } else { - currentFrame = frame; + if (frame.fin === 0) { + fragmentStartFrameHeader = frame; + currentFrameHeader = frame; + continue main; } + currentFrameHeader = frame; + continue main; } } } catch (error) { diff --git a/src/reactive-rpc/server/ws/server/__tests__/WsServerConnection.spec.ts b/src/reactive-rpc/server/ws/server/__tests__/WsServerConnection.spec.ts new file mode 100644 index 0000000000..4b50147b73 --- /dev/null +++ b/src/reactive-rpc/server/ws/server/__tests__/WsServerConnection.spec.ts @@ -0,0 +1,287 @@ +import * as stream from 'stream'; +import {WsServerConnection} from '../WsServerConnection'; +import {WsFrameEncoder} from '../../codec/WsFrameEncoder'; +import {until} from 'thingies'; +import {WsFrameOpcode} from '../../codec'; +import {bufferToUint8Array} from '../../../../../util/buffers/bufferToUint8Array'; +import {listToUint8} from '../../../../../util/buffers/concat'; + +const setup = () => { + const socket = new stream.PassThrough(); + const encoder = new WsFrameEncoder(); + const connection = new WsServerConnection(encoder, socket); + return {socket, encoder, connection}; +}; + +describe('.onping', () => { + test('can parse PING frame', async () => { + const {socket, encoder, connection} = setup(); + const pings: Uint8Array[] = []; + connection.onping = (data: Uint8Array | null): void => { + if (data) pings.push(data); + }; + const pingFrame = encoder.encodePing(Buffer.from([0x01, 0x02, 0x03])); + socket.write(pingFrame); + await until(() => pings.length === 1); + expect(pings[0]).toEqual(new Uint8Array([0x01, 0x02, 0x03])); + }); + + test('can parse empty PING frame', async () => { + const {socket, encoder, connection} = setup(); + const pings: Uint8Array[] = []; + connection.onping = (data: Uint8Array | null): void => { + if (data) pings.push(data); + }; + const pingFrame = encoder.encodePing(Buffer.from([0x01, 0x02, 0x03])); + socket.write(pingFrame); + const pingFrame2 = encoder.encodePing(Buffer.from([])); + socket.write(pingFrame2); + await until(() => pings.length === 2); + expect(pings[0]).toEqual(new Uint8Array([0x01, 0x02, 0x03])); + expect(pings[1]).toEqual(new Uint8Array([])); + }); +}); + +describe('.onping', () => { + test('can parse PONG frame', async () => { + const {socket, encoder, connection} = setup(); + const pongs: Uint8Array[] = []; + connection.onpong = (data: Uint8Array | null): void => { + if (data) pongs.push(data); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + socket.write(pingFrame); + await until(() => pongs.length === 1); + expect(pongs[0]).toEqual(new Uint8Array([0x01, 0x02, 0x03])); + }); +}); + +describe('.onclose', () => { + test('can parse CLOSE frame', async () => { + const {socket, encoder, connection} = setup(); + const closes: [code: number, reason: string][] = []; + connection.onclose = (code: number, reason: string): void => { + closes.push([code, reason]); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + socket.write(pingFrame); + const closeFrame = encoder.encodeClose('OK', 1000); + socket.write(closeFrame); + await until(() => closes.length === 1); + expect(closes[0]).toEqual([1000, 'OK']); + }); +}); + +describe('.onmessage', () => { + describe('un-masked', () => { + test('binary data frame', async () => { + const {socket, encoder, connection} = setup(); + const messages: [data: Uint8Array, isUtf8: boolean][] = []; + connection.onmessage = (data: Uint8Array, isUtf8: boolean): void => { + messages.push([data, isUtf8]); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + const frame = encoder.encodeHdr(1, WsFrameOpcode.BINARY, 3, 0); + encoder.writer.buf(Buffer.from([0x01, 0x02, 0x03]), 3); + const payload = encoder.writer.flush(); + socket.write(pingFrame); + socket.write(frame); + socket.write(payload); + await until(() => messages.length === 1); + expect(messages[0]).toEqual([new Uint8Array([0x01, 0x02, 0x03]), false]); + }); + + test('two binary data frames', async () => { + const {socket, encoder, connection} = setup(); + const messages: [data: Uint8Array, isUtf8: boolean][] = []; + connection.onmessage = (data: Uint8Array, isUtf8: boolean): void => { + messages.push([data, isUtf8]); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + const frame1 = encoder.encodeHdr(1, WsFrameOpcode.BINARY, 3, 0); + encoder.writer.buf(Buffer.from([0x01, 0x02, 0x03]), 3); + const payload1 = encoder.writer.flush(); + const frame2 = encoder.encodeHdr(1, WsFrameOpcode.BINARY, 3, 0); + encoder.writer.buf(Buffer.from([0x04, 0x05, 0x06]), 3); + const payload2 = encoder.writer.flush(); + socket.write(pingFrame); + socket.write(listToUint8([frame1, payload1, frame2, payload2])); + await until(() => messages.length === 2); + expect(messages[0]).toEqual([new Uint8Array([0x01, 0x02, 0x03]), false]); + expect(messages[1]).toEqual([new Uint8Array([0x04, 0x05, 0x06]), false]); + }); + + test('errors when incoming message is too large', async () => { + const {socket, encoder, connection} = setup(); + connection.maxIncomingMessage = 3; + const messages: [data: Uint8Array, isUtf8: boolean][] = []; + connection.onmessage = (data: Uint8Array, isUtf8: boolean): void => { + messages.push([data, isUtf8]); + }; + const closes: [code: number, reason: string][] = []; + connection.onclose = (code: number, reason: string): void => { + closes.push([code, reason]); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + const frame1 = encoder.encodeHdr(1, WsFrameOpcode.BINARY, 3, 0); + encoder.writer.buf(Buffer.from([0x01, 0x02, 0x03]), 3); + const payload1 = encoder.writer.flush(); + const frame2 = encoder.encodeHdr(1, WsFrameOpcode.BINARY, 4, 0); + encoder.writer.buf(Buffer.from([0x04, 0x05, 0x06, 0x07]), 4); + const payload2 = encoder.writer.flush(); + socket.write(pingFrame); + socket.write(listToUint8([frame1, payload1, frame2, payload2])); + await until(() => messages.length === 1); + await until(() => closes.length === 1); + expect(messages[0]).toEqual([new Uint8Array([0x01, 0x02, 0x03]), false]); + expect(closes[0]).toEqual([1009, 'TOO_LARGE']); + }); + + test('text frame', async () => { + const {socket, encoder, connection} = setup(); + const messages: [data: Uint8Array, isUtf8: boolean][] = []; + connection.onmessage = (data: Uint8Array, isUtf8: boolean): void => { + messages.push([data, isUtf8]); + }; + const pingFrame1 = encoder.encodePing(Buffer.from([0x01])); + const pingFrame2 = encoder.encodePing(Buffer.from([0x01, 0x02, 0x03])); + const closeFrame = encoder.encodeHdr(1, WsFrameOpcode.TEXT, 4, 0); + encoder.writer.buf(Buffer.from('asdf'), 4); + const payload = encoder.writer.flush(); + socket.write(pingFrame1); + socket.write(pingFrame2); + socket.write(closeFrame); + socket.write(payload); + await until(() => messages.length === 1); + expect(messages[0]).toEqual([bufferToUint8Array(Buffer.from('asdf')), true]); + }); + }); + + describe('masked', () => { + test('binary data frame', async () => { + const {socket, encoder, connection} = setup(); + const messages: [data: Uint8Array, isUtf8: boolean][] = []; + connection.onmessage = (data: Uint8Array, isUtf8: boolean): void => { + messages.push([data, isUtf8]); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + const closeFrame = encoder.encodeHdr(1, WsFrameOpcode.BINARY, 3, 0x12345678); + encoder.writeBufXor(Buffer.from([0x01, 0x02, 0x03]), 0x12345678); + const payload = encoder.writer.flush(); + socket.write(pingFrame); + socket.write(closeFrame); + socket.write(payload); + await until(() => messages.length === 1); + expect(messages[0]).toEqual([new Uint8Array([0x01, 0x02, 0x03]), false]); + }); + + test('text frame', async () => { + const {socket, encoder, connection} = setup(); + const messages: [data: Uint8Array, isUtf8: boolean][] = []; + connection.onmessage = (data: Uint8Array, isUtf8: boolean): void => { + messages.push([data, isUtf8]); + }; + const pingFrame1 = encoder.encodePing(Buffer.from([0x01])); + const pingFrame2 = encoder.encodePing(Buffer.from([0x01, 0x02, 0x03])); + const closeFrame = encoder.encodeHdr(1, WsFrameOpcode.TEXT, 4, 0x12345678); + encoder.writeBufXor(Buffer.from('asdf'), 0x12345678); + const payload = encoder.writer.flush(); + socket.write(pingFrame1); + socket.write(pingFrame2); + socket.write(closeFrame); + socket.write(payload); + await until(() => messages.length === 1); + expect(messages[0]).toEqual([bufferToUint8Array(Buffer.from('asdf')), true]); + }); + }); +}); + +describe('.onfragment', () => { + test('parses out message fragments', async () => { + const {socket, encoder, connection} = setup(); + const fragments: [isLast: boolean, data: Uint8Array, isUtf8: boolean][] = []; + connection.onfragment = (isLast: boolean, data: Uint8Array, isUtf8: boolean): void => { + fragments.push([isLast, data, isUtf8]); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + const buf1 = encoder.encodeHdr(0, WsFrameOpcode.BINARY, 3, 0); + encoder.writer.buf(Buffer.from([0x01, 0x02, 0x03]), 3); + const buf2 = encoder.writer.flush(); + const buf3 = encoder.encodeHdr(0, WsFrameOpcode.CONTINUE, 3, 0); + encoder.writer.buf(Buffer.from([0x04, 0x05, 0x06]), 3); + const buf4 = encoder.writer.flush(); + const buf5 = encoder.encodeHdr(1, WsFrameOpcode.CONTINUE, 3, 0); + encoder.writer.buf(Buffer.from([0x07, 0x08, 0x09]), 3); + const buf6 = encoder.writer.flush(); + socket.write(pingFrame); + socket.write(buf1); + socket.write(buf2); + socket.write(buf3); + socket.write(buf4); + socket.write(buf5); + socket.write(buf6); + await until(() => fragments.length === 3); + expect(fragments).toEqual([ + [false, new Uint8Array([0x01, 0x02, 0x03]), false], + [false, new Uint8Array([0x04, 0x05, 0x06]), false], + [true, new Uint8Array([0x07, 0x08, 0x09]), false], + ]); + }); + + describe('when .onfragment is not defined', () => { + test('emits an .onmessage with fully assembled message', async () => { + const {socket, encoder, connection} = setup(); + const messages: [data: Uint8Array, isUtf8: boolean][] = []; + connection.onmessage = (data: Uint8Array, isUtf8: boolean): void => { + messages.push([data, isUtf8]); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + const buf1 = encoder.encodeHdr(0, WsFrameOpcode.BINARY, 3, 0); + encoder.writer.buf(Buffer.from([0x01, 0x02, 0x03]), 3); + const buf2 = encoder.writer.flush(); + const buf3 = encoder.encodeHdr(0, WsFrameOpcode.CONTINUE, 3, 0); + encoder.writer.buf(Buffer.from([0x04, 0x05, 0x06]), 3); + const buf4 = encoder.writer.flush(); + const buf5 = encoder.encodeHdr(1, WsFrameOpcode.CONTINUE, 3, 0); + encoder.writer.buf(Buffer.from([0x07, 0x08, 0x09]), 3); + const buf6 = encoder.writer.flush(); + socket.write(pingFrame); + socket.write(buf1); + socket.write(buf2); + socket.write(buf3); + socket.write(buf4); + socket.write(buf5); + socket.write(buf6); + await until(() => messages.length === 1); + expect(messages).toEqual([[new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09]), false]]); + }); + + test('errors out when incoming message is too large', async () => { + const {socket, encoder, connection} = setup(); + connection.maxIncomingMessage = 8; + const closes: [code: number, reason: string][] = []; + connection.onclose = (code: number, reason: string): void => { + closes.push([code, reason]); + }; + const pingFrame = encoder.encodePong(Buffer.from([0x01, 0x02, 0x03])); + const buf1 = encoder.encodeHdr(0, WsFrameOpcode.BINARY, 3, 0); + encoder.writer.buf(Buffer.from([0x01, 0x02, 0x03]), 3); + const buf2 = encoder.writer.flush(); + const buf3 = encoder.encodeHdr(0, WsFrameOpcode.CONTINUE, 3, 0); + encoder.writer.buf(Buffer.from([0x04, 0x05, 0x06]), 3); + const buf4 = encoder.writer.flush(); + const buf5 = encoder.encodeHdr(1, WsFrameOpcode.CONTINUE, 3, 0); + encoder.writer.buf(Buffer.from([0x07, 0x08, 0x09]), 3); + const buf6 = encoder.writer.flush(); + socket.write(pingFrame); + socket.write(buf1); + socket.write(buf2); + socket.write(buf3); + socket.write(buf4); + socket.write(buf5); + socket.write(buf6); + await until(() => closes.length === 1); + expect(closes).toEqual([[1009, 'TOO_LARGE']]); + }); + }); +});