From 327bba28156bd0d79296f36c8e9207e25b26fb73 Mon Sep 17 00:00:00 2001 From: Josh Megnauth Date: Sun, 31 Dec 2023 07:10:22 -0500 Subject: [PATCH] Fix a few memory leaks in encode strs (Zig) --- Zig/src/neet/encode_decode_strs.zig | 87 ++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 20 deletions(-) diff --git a/Zig/src/neet/encode_decode_strs.zig b/Zig/src/neet/encode_decode_strs.zig index 0a602d0..91b65e7 100644 --- a/Zig/src/neet/encode_decode_strs.zig +++ b/Zig/src/neet/encode_decode_strs.zig @@ -3,11 +3,15 @@ const fmt = std.fmt; const Allocator = std.mem.Allocator; const ArrayList = std.ArrayList; const expectEqualStrings = std.testing.expectEqualStrings; +const expectEqualSlices = std.testing.expectEqualSlices; /// Serialize a slice of strings into a single, flat string. /// /// I'm essentially using a variation of Bencoding. -pub fn serialize_slice_str(allocator: Allocator, strings: []const []const u8) ![]const u8 { +pub fn serialize_slice_str( + allocator: Allocator, + strings: []const []const u8, +) ![]const u8 { // Cumulative length of each prefix + string (total length of each serialized string) var strs_offsets = try allocator.alloc(u64, strings.len); defer allocator.destroy(strs_offsets.ptr); @@ -44,14 +48,11 @@ pub fn serialize_slice_str(allocator: Allocator, strings: []const []const u8) ![ // Strings serialization buf var ser_buf = full_buf[1..total_len]; - std.debug.print("full_buf size: {}\n", .{full_buf.len}); - std.debug.print("ser_buf size: {}\n", .{ser_buf.len}); i = 0; while (i < strings.len) : (i += 1) { const offset = strs_offsets[i]; const s = strings[i]; - std.debug.print("str size: {} and offset: {}\n", .{ s.len, offset }); try serialize_str_buf(ser_buf[offset..], s); } @@ -67,29 +68,37 @@ pub fn serialize_slice_str(allocator: Allocator, strings: []const []const u8) ![ /// Both of these must be freed. /// /// While this is technically Bencode, I'm only supporting flat lists. -pub fn deserialize_slice_str(allocator: Allocator, encoded: []const u8) !IResult([]const []const u8) { - var strings = ArrayList([]const u8).init(); - errdefer strings.deinit(); +pub fn deserialize_slice_str( + allocator: Allocator, + encoded: []const u8, +) !IResult([]const []const u8) { + var strings = ArrayList([]const u8).init(allocator); + errdefer dealloc_strings(strings); // Parse opening delimiter - var buffer = try take_n(encoded, 1); - if (!buffer.parsed[0] == 'a') { + const delim = try take_n(encoded, 1); + if (delim.parsed[0] != 'l') { return ParseError.DelimiterNotFound; } - while (buffer.remainder.len != 1) { - buffer = try deserialize_str(buffer.remainder); + var remainder = delim.remainder; + while (remainder.len != 1) { + const buffer = try deserialize_str(remainder); + remainder = buffer.remainder; // Copy string to give the caller ownership const len = buffer.value.len; const string = try allocator.alloc(u8, len); errdefer allocator.destroy(string.ptr); - @memcpy(string.ptr, buffer.value, len); + @memcpy(string.ptr, buffer.value.ptr, len); try strings.append(string); } - return strings.toOwnedSlice(); + return IResult([]const []const u8){ + .value = strings.toOwnedSlice(), + .remainder = remainder, + }; } /// Deserialize a single string. @@ -131,10 +140,17 @@ fn serialize_str_buf(buf: []u8, s: []const u8) !void { _ = try fmt.bufPrint(buf, ser_str_fmt, .{ s.len, s }); } -const ParseError = error{ DelimiterNotFound, InvalidLength, InvalidUnicode }; +const ParseError = error{ + DelimiterNotFound, + InvalidLength, + InvalidUnicode, +}; const ParseResult = - struct { parsed: []const u8, remainder: []const u8 }; + struct { + parsed: []const u8, + remainder: []const u8, +}; fn IResult(comptime T: type) type { return struct { @@ -144,7 +160,10 @@ fn IResult(comptime T: type) type { remainder: []const u8, fn new(value: T, remainder: []const u8) @This() { - return @This(){ .value = value, .remainder = remainder }; + return @This(){ + .value = value, + .remainder = remainder, + }; } }; } @@ -154,7 +173,10 @@ fn IResult(comptime T: type) type { // Returns an error if delimiter isn't found. fn take_until(s: []const u8, delimiter: []const u8) ParseError!ParseResult { if (delimiter.len == 0) { - return ParseResult{ .parsed = s, .remainder = comptime &[0]u8{} }; + return ParseResult{ + .parsed = s, + .remainder = comptime &[0]u8{}, + }; } // Accumulate everything before the delimiter. @@ -169,7 +191,10 @@ fn take_until(s: []const u8, delimiter: []const u8) ParseError!ParseResult { } // Delimiter found if the loop traversed the length of delimiter if (j == delimiter.len) { - return ParseResult{ .parsed = s[0..i], .remainder = s[i + j .. s.len] }; + return ParseResult{ + .parsed = s[0..i], + .remainder = s[i + j .. s.len], + }; } } @@ -181,7 +206,18 @@ fn take_n(bytes: []const u8, amount: usize) !ParseResult { return ParseError.InvalidLength; } - return ParseResult{ .parsed = bytes[0..amount], .remainder = bytes[amount..] }; + return ParseResult{ + .parsed = bytes[0..amount], + .remainder = bytes[amount..], + }; +} + +fn dealloc_strings(strings: ArrayList([]const u8)) void { + for (strings.items) |str| { + strings.allocator.destroy(str.ptr); + } + + strings.deinit(); } test "serialize strings simple" { @@ -221,10 +257,21 @@ test "deserialize empty string slice" { } test "deserialized string slice correctly owns memory" { - const strings = "l4:meowe"; + const expected = [1][]const u8{"meow"}; + const strings = fmt.comptimePrint("l{}:{s}e", .{ expected.len, expected }); // Deserialization will borrow from `owned` and copy the string into an owned buffer // Freeing `owned` shouldn't cause a use after free const owned = try std.testing.allocator.alloc(u8, strings.len); + errdefer std.testing.allocator.destroy(owned.ptr); @memcpy(owned.ptr, strings, strings.len); + + // Deserialize from the allocated heap buffer then free the buf + const actual = try deserialize_slice_str(std.testing.allocator, strings); + defer std.testing.allocator.destroy(actual.value.ptr); + std.testing.allocator.destroy(owned.ptr); + + const expected_remainder = ""; + try expectEqualStrings(expected_remainder, actual.remainder); + try expectEqualSlices([]const u8, &expected, actual.value); }