diff --git a/Sources/BinaryCodable/Common/AnyCodingKey.swift b/Sources/BinaryCodable/Common/AnyCodingKey.swift new file mode 100644 index 0000000..7912907 --- /dev/null +++ b/Sources/BinaryCodable/Common/AnyCodingKey.swift @@ -0,0 +1,35 @@ + +/// Helpful for appending mixed coding keys. +struct AnyCodingKey: CodingKey { + var stringValue: String + var intValue: Int? + + init(intValue: Int) { + self.intValue = intValue + self.stringValue = "\(intValue)" + } + + init(stringValue: String) { + self.intValue = nil + self.stringValue = stringValue + } +} + +extension AnyCodingKey { + init(_ key: T) { + self.stringValue = key.stringValue + self.intValue = key.intValue + } +} + +extension AnyCodingKey: ExpressibleByStringLiteral { + init(stringLiteral value: String) { + self.init(stringValue: value) + } +} + +extension AnyCodingKey: ExpressibleByIntegerLiteral { + init(integerLiteral value: Int) { + self.init(intValue: value) + } +} diff --git a/Sources/BinaryCodable/Decoding/KeyedDecoder.swift b/Sources/BinaryCodable/Decoding/KeyedDecoder.swift index 2d93b06..76d7734 100644 --- a/Sources/BinaryCodable/Decoding/KeyedDecoder.swift +++ b/Sources/BinaryCodable/Decoding/KeyedDecoder.swift @@ -10,12 +10,33 @@ final class KeyedDecoder: AbstractDecodingNode, KeyedDecodingContainerProto while decoder.hasMoreBytes { let (key, dataType) = try DecodingKey.decode(from: decoder, path: path) - let data = try decoder.getData(for: dataType, path: path) - - guard content[key] != nil else { - content[key] = [data] - continue + do { + let data = try decoder.getData(for: dataType, path: path) + guard content[key] != nil else { + content[key] = [data] + continue + } + } catch DecodingError.dataCorrupted(let context) { + let codingKey = { + switch key { + case .stringKey(let stringValue): + return Key(stringValue: stringValue) + case .intKey(let intValue): + return Key(intValue: intValue) + } + }() + var newCodingPath = path + if let codingKey { + newCodingPath += [codingKey] + } + let newContext = DecodingError.Context( + codingPath: newCodingPath, + debugDescription: context.debugDescription, + underlyingError: context.underlyingError + ) + throw DecodingError.dataCorrupted(newContext) } + throw DecodingError.multipleValuesForKey(path, key) } self.content = content.mapValues { parts in @@ -60,27 +81,33 @@ final class KeyedDecoder: AbstractDecodingNode, KeyedDecodingContainerProto } func decode(_ type: T.Type, forKey key: Key) throws -> T where T : Decodable { - let data = try getData(forKey: key) - if type is AnyOptional.Type { - let node = DecodingNode(data: data, isOptional: true, path: codingPath, info: userInfo) - return try T.init(from: node) - } else if let Primitive = type as? DecodablePrimitive.Type { - return try Primitive.init(decodeFrom: data, path: codingPath + [key]) as! T - } else { - let node = DecodingNode(data: data, path: codingPath, info: userInfo) - return try T.init(from: node) + try wrapError(forKey: key) { + let data = try getData(forKey: key) + if type is AnyOptional.Type { + let node = DecodingNode(data: data, isOptional: true, path: codingPath, info: userInfo) + return try T.init(from: node) + } else if let Primitive = type as? DecodablePrimitive.Type { + return try Primitive.init(decodeFrom: data, path: codingPath + [key]) as! T + } else { + let node = DecodingNode(data: data, path: codingPath, info: userInfo) + return try T.init(from: node) + } } } func nestedContainer(keyedBy type: NestedKey.Type, forKey key: Key) throws -> KeyedDecodingContainer where NestedKey : CodingKey { - let data = try getData(forKey: key) - let container = try KeyedDecoder(data: data, path: codingPath, info: userInfo) - return KeyedDecodingContainer(container) + try wrapError(forKey: key) { + let data = try getData(forKey: key) + let container = try KeyedDecoder(data: data, path: codingPath, info: userInfo) + return KeyedDecodingContainer(container) + } } func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer { - let data = try getData(forKey: key) - return try UnkeyedDecoder(data: data, path: codingPath, info: userInfo) + try wrapError(forKey: key) { + let data = try getData(forKey: key) + return try UnkeyedDecoder(data: data, path: codingPath, info: userInfo) + } } func superDecoder() throws -> Decoder { @@ -92,4 +119,17 @@ final class KeyedDecoder: AbstractDecodingNode, KeyedDecodingContainerProto let data = try getData(forKey: key) return DecodingNode(data: data, path: codingPath, info: userInfo) } + + private func wrapError(forKey key: Key, _ block: () throws -> T) throws -> T { + do { + return try block() + } catch DecodingError.dataCorrupted(let context) { + let newContext = DecodingError.Context( + codingPath: codingPath + [key] + context.codingPath, + debugDescription: context.debugDescription, + underlyingError: context.underlyingError + ) + throw DecodingError.dataCorrupted(newContext) + } + } } diff --git a/Sources/BinaryCodable/Decoding/UnkeyedDecoder.swift b/Sources/BinaryCodable/Decoding/UnkeyedDecoder.swift index 27ac03b..d324434 100644 --- a/Sources/BinaryCodable/Decoding/UnkeyedDecoder.swift +++ b/Sources/BinaryCodable/Decoding/UnkeyedDecoder.swift @@ -50,30 +50,36 @@ final class UnkeyedDecoder: AbstractDecodingNode, UnkeyedDecodingContainer { func decode(_ type: T.Type) throws -> T where T : Decodable { defer { currentIndex += 1 } - if type is AnyOptional.Type { - let node = DecodingNode(decoder: decoder, isOptional: true, path: codingPath, info: userInfo, isInUnkeyedContainer: true) - return try T.init(from: node) - } else if let Primitive = type as? DecodablePrimitive.Type { - let dataType = Primitive.dataType - let data = try decoder.getData(for: dataType, path: codingPath) - return try Primitive.init(decodeFrom: data, path: codingPath) as! T - } else { - let node = DecodingNode(decoder: decoder, path: codingPath, info: userInfo, isInUnkeyedContainer: true) - return try T.init(from: node) + return try wrapError { + if type is AnyOptional.Type { + let node = DecodingNode(decoder: decoder, isOptional: true, path: codingPath, info: userInfo, isInUnkeyedContainer: true) + return try T.init(from: node) + } else if let Primitive = type as? DecodablePrimitive.Type { + let dataType = Primitive.dataType + let data = try decoder.getData(for: dataType, path: codingPath) + return try Primitive.init(decodeFrom: data, path: codingPath) as! T + } else { + let node = DecodingNode(decoder: decoder, path: codingPath, info: userInfo, isInUnkeyedContainer: true) + return try T.init(from: node) + } } } func nestedContainer(keyedBy type: NestedKey.Type) throws -> KeyedDecodingContainer where NestedKey : CodingKey { currentIndex += 1 - let data = try decoder.getData(for: .variableLength, path: codingPath) - let container = try KeyedDecoder(data: data, path: codingPath, info: userInfo) - return KeyedDecodingContainer(container) + return try wrapError { + let data = try decoder.getData(for: .variableLength, path: codingPath) + let container = try KeyedDecoder(data: data, path: codingPath, info: userInfo) + return KeyedDecodingContainer(container) + } } func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { currentIndex += 1 - let data = try decoder.getData(for: .variableLength, path: codingPath) - return try UnkeyedDecoder(data: data, path: codingPath, info: userInfo) + return try wrapError { + let data = try decoder.getData(for: .variableLength, path: codingPath) + return try UnkeyedDecoder(data: data, path: codingPath, info: userInfo) + } } func superDecoder() throws -> Decoder { @@ -85,7 +91,15 @@ final class UnkeyedDecoder: AbstractDecodingNode, UnkeyedDecodingContainer { do { return try block() } catch DecodingError.dataCorrupted(let context) { - throw DecodingError.dataCorruptedError(in: self, debugDescription: context.debugDescription) + var codingPath = codingPath + codingPath.append(AnyCodingKey(intValue: currentIndex)) + codingPath.append(contentsOf: context.codingPath) + let newContext = DecodingError.Context( + codingPath: codingPath, + debugDescription: context.debugDescription, + underlyingError: context.underlyingError + ) + throw DecodingError.dataCorrupted(newContext) } } } diff --git a/Tests/BinaryCodableTests/CodingPathTests.swift b/Tests/BinaryCodableTests/CodingPathTests.swift new file mode 100644 index 0000000..203d7c1 --- /dev/null +++ b/Tests/BinaryCodableTests/CodingPathTests.swift @@ -0,0 +1,339 @@ +import XCTest +import BinaryCodable + +final class CodingPathTests: XCTestCase { + + struct KeyedBox: Codable { + enum CodingKeys: String, CodingKey { + case val + } + + let val: T + + init(_ val: T) { + self.val = val + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.val = try container.decode(T.self, forKey: .val) + } + } + + struct CorruptBool: Codable, Equatable, ExpressibleByBooleanLiteral { + let val: Bool + + init(booleanLiteral value: BooleanLiteralType) { + self.val = value + } + + init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + self.val = try container.decode(Bool.self) + if val == false { + throw DecodingError.dataCorruptedError(in: container, debugDescription: "Found false!") + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(val) + } + } + + func testStructWithCorruptArrayElement() throws { + struct Test: Codable, Equatable { + let arr: [CorruptBool] + } + + let corruptedBytes: [UInt8] = [ + 0b00111010, 118, 97, 108, // String key 'val', varint + 8, // Length of val + 0b00111010, 97, 114, 114, // String key 'arr', varint + 3, // 3 elements + 1, 1, 0 // True, true, corrupt! + ] + + let data = try BinaryEncoder().encode(KeyedBox(Test(arr: [true, true, false]))) + XCTAssertEqual(Array(data), corruptedBytes) + + do { + let _ = try BinaryDecoder().decode(KeyedBox.self, from: data) + XCTFail("Unexpected succeded!") + } catch let error as DecodingError { + guard case .dataCorrupted(let context) = error else { + XCTFail("Unexpected error!") + return + } + XCTAssertEqual(context.codingPath.map { $0.stringValue }, ["val", "arr", "2"]) + } + } + + func testStructWithArrayMissingLastElement() throws { + struct Test: Codable, Equatable { + let arr: [Bool] + } + + let corruptedBytes: [UInt8] = [ + 0b00111010, 118, 97, 108, // String key 'val', varint + 8, // Length of val + 0b00111010, 97, 114, 114, // String key 'arr', varint + 3, // 3 elements + 1, 1 // Only two elements provided! + ] + + var data = try BinaryEncoder().encode(KeyedBox(Test(arr: [true, true, false]))) + data.removeLast() + XCTAssertEqual(Array(data), corruptedBytes) + + do { + let _ = try BinaryDecoder().decode(KeyedBox.self, from: Data(corruptedBytes)) + XCTFail("Unexpected succeded!") + } catch let error as DecodingError { + guard case .dataCorrupted(let context) = error else { + XCTFail("Unexpected error!") + return + } + XCTAssertEqual(context.codingPath.map { $0.stringValue }, ["val"]) + } + } + + func testStructWithArrayMissingLastElementButCorrectLength() throws { + struct Test: Codable, Equatable { + let arr: [Bool] + } + + let corruptedBytes: [UInt8] = [ + 0b00111010, 118, 97, 108, // String key 'val', varint + 7, // Length of val + 0b00111010, 97, 114, 114, // String key 'arr', varint + 3, // 3 elements + 1, 1 // Only two elements provided! + ] + + var data = try BinaryEncoder().encode(KeyedBox(Test(arr: [true, true, false]))) + data[4] -= 1 + data.removeLast() + XCTAssertEqual(Array(data), corruptedBytes) + + do { + let _ = try BinaryDecoder().decode(KeyedBox.self, from: Data(corruptedBytes)) + XCTFail("Unexpected succeded!") + } catch let error as DecodingError { + guard case .dataCorrupted(let context) = error else { + XCTFail("Unexpected error!") + return + } + XCTAssertEqual(context.codingPath.map { $0.stringValue }, ["val", "arr"]) + } + } + + func testStructWithCorruptDataOnKeyedNestedContainer() throws { + struct Test: Codable, Equatable { + enum CodingKeys: String, CodingKey { + case nested + } + + enum NestedCodingKeys: String, CodingKey { + case bool + } + + let bool: CorruptBool + + init(bool: CorruptBool) { + self.bool = bool + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let nestedContainer = try container.nestedContainer(keyedBy: NestedCodingKeys.self, forKey: .nested) + self.bool = try nestedContainer.decode(CorruptBool.self, forKey: .bool) + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + var nestedContainer = container.nestedContainer(keyedBy: NestedCodingKeys.self, forKey: .nested) + try nestedContainer.encode(bool, forKey: .bool) + } + } + + let corruptedBytes: [UInt8] = [ + 0b00111010, 118, 97, 108, // String key 'val', varint + 14, // Byte length of val value + 0b01101010, 110, 101, 115, 116, 101, 100, // String key 'nested', varint + 6, // Byte length of nested value + 0b01001000, 98, 111, 111, 108, // String key 'bool', varint + 0, // Corrupt! + ] + + let data = try BinaryEncoder().encode(KeyedBox(Test(bool: false))) + XCTAssertEqual(Array(data), corruptedBytes) + + do { + let _ = try BinaryDecoder().decode(KeyedBox.self, from: data) + XCTFail("Unexpected succeded!") + } catch let error as DecodingError { + guard case .dataCorrupted(let context) = error else { + XCTFail("Unexpected error!") + return + } + XCTAssertEqual(context.codingPath.map { $0.stringValue }, ["val", "bool"]) + } + } + + func testStructWithCorruptDataOnUnkeyedNestedContainer() throws { + struct TestWrapper: Codable, Equatable { + enum CodingKeys: String, CodingKey { + case nested + } + + enum NestedCodingKeys: String, CodingKey { + case bool + } + + let val: Test + + init(val: Test) { + self.val = val + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + var nestedContainer = try container.nestedUnkeyedContainer(forKey: .nested) + self.val = try nestedContainer.decode(Test.self) + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + var nestedContainer = container.nestedUnkeyedContainer(forKey: .nested) + try nestedContainer.encode(val) + } + } + struct Test: Codable, Equatable { + let bool: CorruptBool + } + + let corruptedBytes: [UInt8] = [ + 0b00111010, 118, 97, 108, // String key 'val', varint + 15, // Byte length of val value + 0b01101010, 110, 101, 115, 116, 101, 100, // String key 'nested', varint + 7, // Byte length of nested value + 6, // Byte length of unkeyed container + 0b01001000, 98, 111, 111, 108, // String key 'bool', varint + 0 // Corrupt! + ] + + let data = try BinaryEncoder().encode(KeyedBox(TestWrapper(val: Test(bool: false)))) + XCTAssertEqual(Array(data), corruptedBytes) + + do { + let _ = try BinaryDecoder().decode(KeyedBox.self, from: data) + XCTFail("Unexpected succeded!") + } catch let error as DecodingError { + guard case .dataCorrupted(let context) = error else { + XCTFail("Unexpected error!") + return + } + XCTAssertEqual(context.codingPath.map { $0.stringValue }, ["val", "0", "bool"]) + } + } + + func testOptionalStructWithMissingValueAndWrongLengths() throws { + struct Test: Codable, Equatable { + let bool: Bool + } + + let corruptedBytes: [UInt8] = [ + 0b00111010, 118, 97, 108, // String key 'val', varint + 8, // Supposed byte length of optional val + 1, // 1 optional, + 6, // Supposed byte length of wrapped val + 0b01001000, 98, 111, 111, 108, // String key 'bool', varint + // No value! + ] + + let box: KeyedBox = KeyedBox(Test(bool: true)) + + var data = try BinaryEncoder().encode(box) + data.removeLast() + XCTAssertEqual(Array(data), corruptedBytes) + + do { + let _ = try BinaryDecoder().decode(KeyedBox.self, from: data) + XCTFail("Unexpected succeded!") + } catch let error as DecodingError { + guard case .dataCorrupted(let context) = error else { + XCTFail("Unexpected error!") + return + } + XCTAssertEqual(context.codingPath.map { $0.stringValue }, ["val"]) + } + } + + func testOptionalStructWithMissingValueAndWrongNestedLength() throws { + struct Test: Codable, Equatable { + let bool: Bool + } + + let corruptedBytes: [UInt8] = [ + 0b00111010, 118, 97, 108, // String key 'val', varint + 7, // Byte length of optional val (modified!) + 1, // 1 as in the optional is present, + 6, // Supposed byte length of wrapped val + 0b01001000, 98, 111, 111, 108, // String key 'bool', varint + // No value! + ] + + let box: KeyedBox = KeyedBox(Test(bool: true)) + + var data = try BinaryEncoder().encode(box) + data[4] -= 1 + data.removeLast() + XCTAssertEqual(Array(data), corruptedBytes) + + do { + let _ = try BinaryDecoder().decode(KeyedBox.self, from: data) + XCTFail("Unexpected succeded!") + } catch let error as DecodingError { + guard case .dataCorrupted(let context) = error else { + XCTFail("Unexpected error!") + return + } + XCTAssertEqual(context.codingPath.map { $0.stringValue }, ["val"]) + } + } + + func testOptionalStructWithMissingValueButCorrectByteLengths() throws { + struct Test: Codable, Equatable { + let bool: Bool + } + + let corruptedBytes: [UInt8] = [ + 0b00111010, 118, 97, 108, // String key 'val', varint + 7, // Byte length of optional val (modified!) + 1, // 1 as in the optional is present, + 5, // Byte length of wrapped val (modified!) + 0b01001000, 98, 111, 111, 108, // String key 'bool', varint + // No value! + ] + + let box: KeyedBox = KeyedBox(Test(bool: true)) + + var data = try BinaryEncoder().encode(box) + data[4] -= 1 + data[6] -= 1 + data.removeLast() + XCTAssertEqual(Array(data), corruptedBytes) + + do { + let _ = try BinaryDecoder().decode(KeyedBox.self, from: data) + XCTFail("Unexpected succeded!") + } catch let error as DecodingError { + guard case .dataCorrupted(let context) = error else { + XCTFail("Unexpected error!") + return + } + XCTAssertEqual(context.codingPath.map { $0.stringValue }, ["val", "bool"]) + } + } +}