diff --git a/lib/network/handler.dart b/lib/network/handler.dart index e3ba3db3..a1d9ec75 100644 --- a/lib/network/handler.dart +++ b/lib/network/handler.dart @@ -329,11 +329,7 @@ class WebSocketChannelHandler extends ChannelHandler { void channelRead(ChannelContext channelContext, Channel channel, Uint8List msg) { proxyChannel.write(msg); WebSocketFrame? frame; - try { - frame = decoder.decode(msg); - } catch (e) { - log.e("websocket decode error", error: e); - } + frame = decoder.decode(msg); if (frame == null) { return; } diff --git a/lib/network/http/websocket.dart b/lib/network/http/websocket.dart index f79fafb6..3d53d120 100644 --- a/lib/network/http/websocket.dart +++ b/lib/network/http/websocket.dart @@ -19,6 +19,8 @@ import 'dart:io'; import 'dart:math'; import 'dart:typed_data'; +import '../util/logger.dart'; + class WebSocketFrame { final bool fin; @@ -67,92 +69,138 @@ class WebSocketFrame { } } +class WebSocketHeader { + final bool fin; // 当前帧是否结束 + final int rsv; // 是否压缩 + final int opcode; //4bit 当前帧类型 + final bool mask; //1bit 是否有掩码 + final int maskingKey; // 掩码 + final int payloadStart; // 消息体起始位置 + final int payloadLength; // 消息体大小 + + WebSocketHeader({ + required this.fin, + required this.rsv, + required this.opcode, + required this.mask, + required this.maskingKey, + required this.payloadStart, + required this.payloadLength, + }); +} + ///websocket 解码器 class WebSocketDecoder { + List buffer = []; // 单独创建一个缓存,以解决数据帧不完整的问题 + WebSocketFrame? decode(Uint8List byteBuf) { - var frame = _parseWebSocketFrame(byteBuf); + WebSocketFrame? frame; + buffer.addAll(byteBuf); // 所有的数据都从缓存中读取 + try { + //先解析WebSocket Header + if (buffer.length < 2) { + // logger.w("报文缓存中的大小不够,无法解析Socket头 => ${buffer.length}"); + return null; + } + var reader = ByteData.sublistView(Uint8List.fromList(buffer)); + WebSocketHeader frameHeader = parseWebSocketHeader(reader); + // 大小不足时不解析Frame + if (buffer.length < frameHeader.payloadStart + frameHeader.payloadLength) { + return null; + } + // 处理报文 + var payloadData = Uint8List.fromList(buffer).sublist( + frameHeader.payloadStart, + frameHeader.payloadStart + frameHeader.payloadLength); + // 先解掩码 + if (frameHeader.mask) { + payloadData = unmaskPayload(payloadData, frameHeader.maskingKey); + } + // 再解压 + if (frameHeader.rsv == 1) { + payloadData = decompress(payloadData); + } + // 构建Frame + frame = WebSocketFrame( + fin: frameHeader.fin, + opcode: frameHeader.opcode, + mask: frameHeader.mask, + payloadLength: frameHeader.payloadLength, + maskingKey: frameHeader.maskingKey, + payloadData: payloadData); + // 整理buffer + buffer = + buffer.sublist(frameHeader.payloadStart + frameHeader.payloadLength); + } catch (e, s) { + logger.e("websocket decode error", error: e, stackTrace: s); + } return frame; } - bool canParseWebSocketFrame(Uint8List data) { - if (data.length < 2) { - return false; - } - - var reader = ByteData.sublistView(data); + WebSocketHeader parseWebSocketHeader(ByteData reader) { + var fin = reader.getUint8(0) >> 7; + //解析 rsv1 todo - 待支持rsv2,rsv3 + var rsv1 = (reader.getUint8(0) >> 6) & 0x01; var opcode = reader.getUint8(0) & 0x0f; - if (opcode > 0xA) { - return false; - } var mask = reader.getUint8(1) >> 7; - int payloadStart = 2; - if (mask == 1) { - payloadStart += 4; - } var payloadLength = reader.getUint8(1) & 0x7f; + + int payloadStart = 2; + if (payloadLength == 126) { + payloadLength = reader.getUint16(2, Endian.big); payloadStart += 2; } else if (payloadLength == 127) { + payloadLength = reader.getUint64(2, Endian.big); payloadStart += 8; } + var maskingKey = 0; + if (mask == 1) { + maskingKey = reader.getUint32(payloadStart); + payloadStart += 4; + } + return WebSocketHeader( + fin: fin == 1, + rsv: rsv1, + opcode: opcode, + mask: mask == 1, + maskingKey: maskingKey, + payloadStart: payloadStart, + payloadLength: payloadLength); + } - if (data.length < payloadStart + payloadLength) { + bool canParseWebSocketFrame(Uint8List data) { + if (data.length < 2) { return false; } - return true; - } - WebSocketFrame _parseWebSocketFrame(Uint8List data) { var reader = ByteData.sublistView(data); - var fin = reader.getUint8(0) >> 7; - //解析 rsv1 - var rsv1 = (reader.getUint8(0) >> 6) & 0x01; - var opcode = reader.getUint8(0) & 0x0f; + if (opcode > 0xA) { + return false; // opCode超出范围说明是异常的 + } var mask = reader.getUint8(1) >> 7; - - var payloadLength = reader.getUint8(1) & 0x7f; - int payloadStart = 2; + if (mask == 1) { + payloadStart += 4; + } + var payloadLength = reader.getUint8(1) & 0x7f; if (payloadLength == 126) { - payloadLength = reader.getUint16(2); payloadStart += 2; } else if (payloadLength == 127) { - payloadLength = reader.getUint64(2); payloadStart += 8; } - var maskingKey = 0; - if (mask == 1) { - maskingKey = reader.getUint32(payloadStart); - payloadStart += 4; - } - - var payloadData = data.sublist(payloadStart, min(payloadStart + payloadLength, data.length)); - - //根据maskKey解密内容 - if (mask == 1) { - payloadData = unmaskPayload(payloadData, maskingKey); - } - - if (rsv1 == 1) { - //inflate - payloadData = decompress(payloadData); + if (data.length < payloadStart + payloadLength) { + return false; } - return WebSocketFrame( - fin: fin == 1, - opcode: opcode, - mask: mask == 1, - maskingKey: maskingKey, - payloadLength: payloadLength, - payloadData: payloadData, - ); + return true; } ZLibDecoder? _decoder; @@ -161,6 +209,7 @@ class WebSocketDecoder { Uint8List decompress(Uint8List msg) { try { + // todo - 这个方法没有办法正确解析rsv1为1的payload. return Uint8List.fromList(_ensureDecoder().convert(msg)); } catch (e) { return msg;