Skip to content

Commit

Permalink
Fix a few memory leaks in encode strs (Zig)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuamegnauth54 committed Dec 31, 2023
1 parent c55906e commit 327bba2
Showing 1 changed file with 67 additions and 20 deletions.
87 changes: 67 additions & 20 deletions Zig/src/neet/encode_decode_strs.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ const fmt = std.fmt;
const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayList;
const expectEqualStrings = std.testing.expectEqualStrings;
const expectEqualSlices = std.testing.expectEqualSlices;

/// Serialize a slice of strings into a single, flat string.
///
/// I'm essentially using a variation of Bencoding.
pub fn serialize_slice_str(allocator: Allocator, strings: []const []const u8) ![]const u8 {
pub fn serialize_slice_str(
allocator: Allocator,
strings: []const []const u8,
) ![]const u8 {
// Cumulative length of each prefix + string (total length of each serialized string)
var strs_offsets = try allocator.alloc(u64, strings.len);
defer allocator.destroy(strs_offsets.ptr);
Expand Down Expand Up @@ -44,14 +48,11 @@ pub fn serialize_slice_str(allocator: Allocator, strings: []const []const u8) ![

// Strings serialization buf
var ser_buf = full_buf[1..total_len];
std.debug.print("full_buf size: {}\n", .{full_buf.len});
std.debug.print("ser_buf size: {}\n", .{ser_buf.len});

i = 0;
while (i < strings.len) : (i += 1) {
const offset = strs_offsets[i];
const s = strings[i];
std.debug.print("str size: {} and offset: {}\n", .{ s.len, offset });
try serialize_str_buf(ser_buf[offset..], s);
}

Expand All @@ -67,29 +68,37 @@ pub fn serialize_slice_str(allocator: Allocator, strings: []const []const u8) ![
/// Both of these must be freed.
///
/// While this is technically Bencode, I'm only supporting flat lists.
///
/// `encoded` must begin with the list delimiter `l` and the list must be closed by `e`
/// (for example `l4:meowe`).
///
/// Every string in the returned `value` is a fresh copy made with `allocator`, so the caller
/// must free each string and then the outer slice. `remainder` is the input that follows the
/// closing `e`; it still borrows from `encoded`.
///
/// Returns `ParseError.DelimiterNotFound` when the opening `l` or the closing `e` is missing,
/// and propagates errors from `take_n`, `deserialize_str` and the allocator.
pub fn deserialize_slice_str(
    allocator: Allocator,
    encoded: []const u8,
) !IResult([]const []const u8) {
    var strings = ArrayList([]const u8).init(allocator);
    // On any failure below, release every string copied so far and the list storage itself.
    errdefer dealloc_strings(strings);

    // Parse opening delimiter
    const delim = try take_n(encoded, 1);
    if (delim.parsed[0] != 'l') {
        return ParseError.DelimiterNotFound;
    }

    // Keep parsing items until the closing delimiter (or the end of input) is reached.
    var remainder = delim.remainder;
    while (remainder.len != 0 and remainder[0] != 'e') {
        const buffer = try deserialize_str(remainder);
        remainder = buffer.remainder;

        // Copy string to give the caller ownership
        const string = try allocator.dupe(u8, buffer.value);
        // Slices are released with `free`; `destroy` is only for single-item pointers.
        errdefer allocator.free(string);

        try strings.append(string);
    }

    // Parse closing delimiter: running out of input means the list was never terminated.
    if (remainder.len == 0) {
        return ParseError.DelimiterNotFound;
    }
    remainder = remainder[1..];

    return IResult([]const []const u8){
        // NOTE(review): `toOwnedSlice` is fallible on Zig >= 0.11; drop the `try` on 0.10.
        .value = try strings.toOwnedSlice(),
        .remainder = remainder,
    };
}

/// Deserialize a single string.
Expand Down Expand Up @@ -131,10 +140,17 @@ fn serialize_str_buf(buf: []u8, s: []const u8) !void {
_ = try fmt.bufPrint(buf, ser_str_fmt, .{ s.len, s });
}

/// Errors produced by the parsing helpers in this file.
const ParseError = error{
    /// An expected delimiter was not present in the input.
    DelimiterNotFound,
    /// The input is too short for the number of bytes requested.
    InvalidLength,
    InvalidUnicode,
};

/// Result of a byte-level parsing step: the bytes that were consumed and what is left over.
/// Both slices borrow from the input that was parsed.
const ParseResult = struct {
    /// Bytes taken from the front of the input.
    parsed: []const u8,
    /// Input left after the parsed bytes.
    remainder: []const u8,
};

fn IResult(comptime T: type) type {
return struct {
Expand All @@ -144,7 +160,10 @@ fn IResult(comptime T: type) type {
remainder: []const u8,

fn new(value: T, remainder: []const u8) @This() {
return @This(){ .value = value, .remainder = remainder };
return @This(){
.value = value,
.remainder = remainder,
};
}
};
}
Expand All @@ -154,7 +173,10 @@ fn IResult(comptime T: type) type {
// Returns an error if delimiter isn't found.
fn take_until(s: []const u8, delimiter: []const u8) ParseError!ParseResult {
if (delimiter.len == 0) {
return ParseResult{ .parsed = s, .remainder = comptime &[0]u8{} };
return ParseResult{
.parsed = s,
.remainder = comptime &[0]u8{},
};
}

// Accumulate everything before the delimiter.
Expand All @@ -169,7 +191,10 @@ fn take_until(s: []const u8, delimiter: []const u8) ParseError!ParseResult {
}
// Delimiter found if the loop traversed the length of delimiter
if (j == delimiter.len) {
return ParseResult{ .parsed = s[0..i], .remainder = s[i + j .. s.len] };
return ParseResult{
.parsed = s[0..i],
.remainder = s[i + j .. s.len],
};
}
}

Expand All @@ -181,7 +206,18 @@ fn take_n(bytes: []const u8, amount: usize) !ParseResult {
return ParseError.InvalidLength;
}

return ParseResult{ .parsed = bytes[0..amount], .remainder = bytes[amount..] };
return ParseResult{
.parsed = bytes[0..amount],
.remainder = bytes[amount..],
};
}

/// Free every string stored in `strings`, then the list's own backing storage.
///
/// Each item must be a slice that was allocated with `strings.allocator`.
fn dealloc_strings(strings: ArrayList([]const u8)) void {
    for (strings.items) |str| {
        // Slices are released with `free`, which needs the full length of the allocation.
        // `destroy` is only for single-item pointers obtained from `create`.
        strings.allocator.free(str);
    }

    strings.deinit();
}

test "serialize strings simple" {
Expand Down Expand Up @@ -221,10 +257,21 @@ test "deserialize empty string slice" {
}

test "deserialized string slice correctly owns memory" {
    const allocator = std.testing.allocator;

    const expected = [1][]const u8{"meow"};
    // Encode the single expected string as a one-item list: "l4:meowe".
    const strings = fmt.comptimePrint("l{d}:{s}e", .{ expected[0].len, expected[0] });

    // Deserialization will borrow from `owned` and copy the string into an owned buffer
    // Freeing `owned` shouldn't cause a use after free
    const owned = try allocator.dupe(u8, strings);
    var owned_live = true;
    // Only free here if the test bails out before the explicit free below.
    defer if (owned_live) allocator.free(owned);

    // Deserialize from the allocated heap buffer
    const actual = try deserialize_slice_str(allocator, owned);
    defer {
        for (actual.value) |string| allocator.free(string);
        allocator.free(actual.value);
    }

    // `remainder` borrows from `owned`, so it has to be checked before the buffer is freed.
    const expected_remainder = "";
    try expectEqualStrings(expected_remainder, actual.remainder);

    // Free the source buffer; the deserialized strings must stay valid.
    allocator.free(owned);
    owned_live = false;

    // Compare contents one string at a time: comparing the outer slices element-wise would
    // compare the inner slices' pointers rather than their bytes.
    try std.testing.expectEqual(expected.len, actual.value.len);
    var i: usize = 0;
    while (i < expected.len) : (i += 1) {
        try expectEqualStrings(expected[i], actual.value[i]);
    }
}

0 comments on commit 327bba2

Please sign in to comment.