Skip to content

Commit

Permalink
Merge pull request #2415 from square/jszumski/swift-proto3-identity-v…
Browse files Browse the repository at this point in the history
…alues

Swift: Improve generated code for proto3 messages
  • Loading branch information
jszumski authored Mar 28, 2023
2 parents e7ee153 + 9d764b6 commit 7878118
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 22 deletions.
15 changes: 15 additions & 0 deletions wire-runtime-swift/src/main/swift/ProtoCodable/ProtoCodable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,18 @@ extension ProtoDecodable {
}

}

extension ProtoEnum where Self : RawRepresentable, RawValue == UInt32 {

/**
A convenience function used with enum fields that throws an error if the field is null
and its default value can't be used instead.
*/
public static func defaultIfMissing(_ value: Self?) throws -> Self {
guard let value = value ?? Self(rawValue: 0) else {
throw ProtoDecoder.Error.missingEnumDefaultValue(type: Self.self)
}
return value
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public final class ProtoDecoder {
case mapEntryWithoutKey(value: Any?)
case mapEntryWithoutValue(key: Any)
case messageWithoutLength
case missingEnumDefaultValue(type: Any.Type)
case missingRequiredField(typeName: String, fieldName: String)
case recursionLimitExceeded
case unexpectedEndOfData
Expand Down Expand Up @@ -80,6 +81,8 @@ public final class ProtoDecoder {
return "Map entry with \(key) did not include a value."
case .messageWithoutLength:
return "Attempting to decode a message without first decoding the length of that message."
case let .missingEnumDefaultValue(type):
return "Could not assign a default value of 0 for enum type \(String(describing: type))"
case let .missingRequiredField(typeName, fieldName):
return "Required field \(fieldName) for type \(typeName) is not included in the message data."
case let .boxedValueMissingField(type):
Expand Down
15 changes: 15 additions & 0 deletions wire-runtime-swift/src/main/swift/ProtoCodable/ProtoWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,21 @@ public final class ProtoWriter {
try value.encode(to: self)
}

/** Encode a required `bool` field */
public func encode(tag: UInt32, value: Bool) throws {
if value == false && isProto3 { return }
try encode(tag: tag, value: value as Bool?)
}

/** Encode an optional `bool` field */
public func encode(tag: UInt32, value: Bool?) throws {
guard let value = value else { return }

let key = ProtoWriter.makeFieldKey(tag: tag, wireType: .varint)
writeVarint(key)
try value.encode(to: self)
}

/** Encode a required `int32`, `sfixed32`, or `sint32` field */
public func encode(tag: UInt32, value: Int32, encoding: ProtoIntEncoding = .variable) throws {
// Don't encode default values if using proto3 syntax.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ extension Duration : ProtoMessage {

extension Duration : Proto3Codable {
public init(from reader: ProtoReader) throws {
var seconds: Int64? = nil
var nanos: Int32? = nil
var seconds: Int64 = 0
var nanos: Int32 = 0

let token = try reader.beginMessage()
while let tag = try reader.nextTag(token: token) {
Expand All @@ -105,8 +105,8 @@ extension Duration : Proto3Codable {
}
self.unknownFields = try reader.endMessage(token: token)

self.seconds = try Duration.checkIfMissing(seconds, "seconds")
self.nanos = try Duration.checkIfMissing(nanos, "nanos")
self.seconds = seconds
self.nanos = nanos
}

public func encode(to writer: ProtoWriter) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ extension Timestamp : ProtoMessage {

extension Timestamp : Proto3Codable {
public init(from reader: ProtoReader) throws {
var seconds: Int64? = nil
var nanos: Int32? = nil
var seconds: Int64 = 0
var nanos: Int32 = 0

let token = try reader.beginMessage()
while let tag = try reader.nextTag(token: token) {
Expand All @@ -117,8 +117,8 @@ extension Timestamp : Proto3Codable {
}
self.unknownFields = try reader.endMessage(token: token)

self.seconds = try Timestamp.checkIfMissing(seconds, "seconds")
self.nanos = try Timestamp.checkIfMissing(nanos, "nanos")
self.seconds = seconds
self.nanos = nanos
}

public func encode(to writer: ProtoWriter) throws {
Expand Down
16 changes: 16 additions & 0 deletions wire-runtime-swift/src/test/proto/empty.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,21 @@ message EmptyMessage {
}

message EmptyOmitted {
enum EmptyEnum {
UNKNOWN = 0;
OTHER = 1;
}

message EmptyNested {
int32 nested = 1;
}

int32 numeric_value = 1;
string string_value = 2;
bytes bytes_value = 3;
bool bool_value = 4;
EmptyEnum enum_value = 5;
EmptyNested message_value = 6;
repeated string repeated_value = 7;
map<int32, string> map_value = 8;
}
2 changes: 1 addition & 1 deletion wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ final class ProtoEncoderTests: XCTestCase {
}

func testEncodeEmptyProtoMessageWithIdentityValues() throws {
let object = EmptyOmitted(numeric_value: 0)
let object = EmptyOmitted(numeric_value: 0, string_value: "", bytes_value: .init(), bool_value: false, enum_value: .UNKNOWN)
let encoder = ProtoEncoder()
let data = try encoder.encode(object)

Expand Down
22 changes: 22 additions & 0 deletions wire-runtime-swift/src/test/swift/RoundTripTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,26 @@ final class RoundTripTests: XCTestCase {
XCTAssertEqual(decodedPerson, person)
}

// ensure that fields set to their identity value survive a roundtrip when omitted over the wire
func testProto3IdentityValues() throws {
let empty = EmptyOmitted(
numeric_value: 0,
string_value: "",
bytes_value: Data(),
bool_value: false,
enum_value: .UNKNOWN,
message_value: nil,
repeated_value: [],
map_value: [:]
)

let encoder = ProtoEncoder()
let data = try encoder.encode(empty)

let decoder = ProtoDecoder()
let decodedEmpty = try decoder.decode(EmptyOmitted.self, from: data)

XCTAssertEqual(decodedEmpty, empty)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,22 @@ class SwiftGenerator private constructor(
else -> null
}

// see https://protobuf.dev/programming-guides/proto3/#default
private val Field.proto3InitialValue: String
get() = when {
isMap -> "[:]"
isRepeated -> "[]"
isOptional -> "nil"
else -> when (typeName.makeNonOptional()) {
BOOL -> "false"
DOUBLE, FLOAT -> "0"
INT32, UINT32, INT64, UINT64 -> "0"
STRING -> """""""" // evaluates to the empty string
DATA -> ".init()"
else -> "nil"
}
}

private val Field.codableName: String?
get() = jsonName?.takeIf { it != name } ?: camelCase(name).takeIf { it != name }

Expand Down Expand Up @@ -192,6 +208,27 @@ class SwiftGenerator private constructor(

private val MessageType.isHeapAllocated get() = fields.size + oneOfs.size >= 16

/**
* Checks that every enum in a proto3 message contains a value with tag 0.
*
* @throws NoSuchElementException if the case doesn't exist
*/
@Throws(NoSuchElementException::class)
private fun validateProto3DefaultsExist(type: MessageType) {
if (type.syntax == PROTO_2) { return }

// validate each enum field
type
.fields
.mapNotNull { schema.getType(it.type!!) as? EnumType }
.forEach { enum ->
// ensure that a 0 case exists
if (enum.constants.filter { it.tag == 0 }.isEmpty()) {
throw NoSuchElementException("Missing a zero value for ${enum.name}")
}
}
}

@OptIn(ExperimentalStdlibApi::class) // TODO move to build flag
private fun generateMessage(
type: MessageType,
Expand All @@ -208,6 +245,8 @@ class SwiftGenerator private constructor(

val typeSpecs = mutableListOf<TypeSpec>()

validateProto3DefaultsExist(type)

typeSpecs += TypeSpec.structBuilder(structType)
.addModifiers(PUBLIC)
.apply {
Expand Down Expand Up @@ -449,18 +488,30 @@ class SwiftGenerator private constructor(
.addParameter("from", reader, protoReader)
.throws(true)
.apply {
// Declare locals into which everything is writen before promoting to members.
// Declare locals into which everything is written before promoting to members.
type.fields.forEach { field ->
val localType = if (field.isRepeated || field.isMap) {
field.typeName
} else {
field.typeName.makeOptional()
val localType = when (type.syntax) {
PROTO_2 -> if (field.isRepeated || field.isMap) {
field.typeName
} else {
field.typeName.makeOptional()
}
PROTO_3 -> if (field.isOptional || (field.isEnum && !field.isRepeated)) {
field.typeName.makeOptional()
} else {
field.typeName
}
}
val initializer = when {
field.isMap -> "[:]"
field.isRepeated -> "[]"
else -> "nil"

val initializer = when (type.syntax) {
PROTO_2 -> when {
field.isMap -> "[:]"
field.isRepeated -> "[]"
else -> "nil"
}
PROTO_3 -> field.proto3InitialValue
}

addStatement("var %N: %T = %L", field.name, localType, initializer)
}
type.oneOfs.forEach { oneOf ->
Expand Down Expand Up @@ -533,10 +584,17 @@ class SwiftGenerator private constructor(
// Check required and bind members.
addStatement("")
type.fields.forEach { field ->
val initializer = if (field.isOptional || field.isRepeated || field.isMap) {
CodeBlock.of("%N", field.name)
} else {
CodeBlock.of("try %1T.checkIfMissing(%2N, %2S)", structType, field.name)
val initializer = when(type.syntax) {
PROTO_2 -> if (field.isOptional || field.isRepeated || field.isMap) {
CodeBlock.of("%N", field.name)
} else {
CodeBlock.of("try %1T.checkIfMissing(%2N, %2S)", structType, field.name)
}
PROTO_3 -> if (field.isEnum && !field.isRepeated) {
CodeBlock.of("try %1T.defaultIfMissing(%2N)", field.typeName.makeNonOptional(), field.name)
} else {
CodeBlock.of("%N", field.name)
}
}
addStatement("self.%N = %L", field.name, initializer)
}
Expand Down

0 comments on commit 7878118

Please sign in to comment.