diff --git a/src/build.zig b/src/build.zig index bb78820..d7c8f78 100644 --- a/src/build.zig +++ b/src/build.zig @@ -49,6 +49,23 @@ pub fn writeFalse(writer: anytype) !void { try writeSimple(writer, 20); } +pub fn writeFloat(writer: anytype, f: anytype) !void { + const T = @TypeOf(f); + const TInf = @typeInfo(T); + + switch (TInf) { + .Float => |float| { + switch (float.bits) { + 16 => try cbor.encode_2(writer, 0xe0, @as(u64, @intCast(@as(u16, @bitCast(f))))), + 32 => try cbor.encode_4(writer, 0xe0, @as(u64, @intCast(@as(u32, @bitCast(f))))), + 64 => try cbor.encode_8(writer, 0xe0, @as(u64, @intCast(@as(u64, @bitCast(f))))), + else => @compileError("Float must be 16, 32 or 64 Bits wide"), + } + }, + else => return error.NotAFloat, + } +} + /// Write the header of an array to `writer`. /// /// You must write exactly `len` data items to `writer` afterwards. @@ -451,3 +468,43 @@ test "write true false" { try std.testing.expectEqual(@as(u8, 0xf5), arr.items[0]); try std.testing.expectEqual(@as(u8, 0xf4), arr.items[1]); } + +test "write float #1" { + const allocator = std.testing.allocator; + var arr = std.ArrayList(u8).init(allocator); + defer arr.deinit(); + + try writeFloat(arr.writer(), @as(f16, @floatCast(0.0))); + + try std.testing.expectEqualSlices(u8, "\xf9\x00\x00", arr.items); +} + +test "write float #2" { + const allocator = std.testing.allocator; + var arr = std.ArrayList(u8).init(allocator); + defer arr.deinit(); + + try writeFloat(arr.writer(), @as(f16, @floatCast(-0.0))); + + try std.testing.expectEqualSlices(u8, "\xf9\x80\x00", arr.items); +} + +test "write float #3" { + const allocator = std.testing.allocator; + var arr = std.ArrayList(u8).init(allocator); + defer arr.deinit(); + + try writeFloat(arr.writer(), @as(f32, @floatCast(3.4028234663852886e+38))); + + try std.testing.expectEqualSlices(u8, "\xfa\x7f\x7f\xff\xff", arr.items); +} + +test "write float #4" { + const allocator = std.testing.allocator; + var arr = std.ArrayList(u8).init(allocator); + defer arr.deinit(); + + try writeFloat(arr.writer(), @as(f64, @floatCast(-4.1))); + + try std.testing.expectEqualSlices(u8, "\xfb\xc0\x10\x66\x66\x66\x66\x66\x66", arr.items); +} diff --git a/src/cbor.zig b/src/cbor.zig index de6a2e3..ab93d7b 100644 --- a/src/cbor.zig +++ b/src/cbor.zig @@ -189,6 +189,45 @@ pub const DataItem = struct { } } + pub fn isFloat16(self: @This()) bool { + if (self.data.len < 3) return false; + return self.data[0] == 0xf9; + } + + pub fn getFloat16(self: @This()) ?f16 { + if (!self.isFloat16()) return null; + if (additionalInfo(self.data, null)) |v| { + return @floatCast(@as(f16, @bitCast(@as(u16, @intCast(v))))); + } + return null; + } + + pub fn isFloat32(self: @This()) bool { + if (self.data.len < 5) return false; + return self.data[0] == 0xfa; + } + + pub fn getFloat32(self: @This()) ?f32 { + if (!self.isFloat32()) return null; + if (additionalInfo(self.data, null)) |v| { + return @floatCast(@as(f32, @bitCast(@as(u32, @intCast(v))))); + } + return null; + } + + pub fn isFloat64(self: @This()) bool { + if (self.data.len < 9) return false; + return self.data[0] == 0xfb; + } + + pub fn getFloat64(self: @This()) ?f64 { + if (!self.isFloat64()) return null; + if (additionalInfo(self.data, null)) |v| { + return @floatCast(@as(f64, @bitCast(@as(u64, @intCast(v))))); + } + return null; + } + /// Decode the given DataItem into a Tag /// /// This function will return null if the DataItem @@ -673,6 +712,12 @@ test "deserialize float" { const di4 = try DataItem.new("\xfb\x7e\x37\xe4\x3c\x88\x00\x75\x9c"); try std.testing.expectEqual(Type.Float, di4.getType()); try std.testing.expectApproxEqAbs(di4.float().?, 1.0e+300, 0.000000001); + + try std.testing.expectEqual(@as(f16, 0.0), (try DataItem.new("\xf9\x00\x00")).getFloat16().?); + try std.testing.expectEqual(@as(f16, -0.0), (try DataItem.new("\xf9\x80\x00")).getFloat16().?); + try std.testing.expectEqual(@as(f16, 65504.0), (try DataItem.new("\xf9\x7b\xff")).getFloat16().?); + try std.testing.expectEqual(@as(f32, 3.4028234663852886e+38), (try DataItem.new("\xfa\x7f\x7f\xff\xff")).getFloat32().?); + try std.testing.expectEqual(@as(f64, -4.1), (try DataItem.new("\xfb\xc0\x10\x66\x66\x66\x66\x66\x66")).getFloat64().?); } test "deserialize tagged" {