From 94b85617b02ed52359847365a5a6b76ed38c32b4 Mon Sep 17 00:00:00 2001 From: Jesper Lundgren Date: Mon, 2 Mar 2015 13:36:02 +0800 Subject: [PATCH 1/2] simple required fields check on message deserialization --- .../scalabuff/compiler/Generator.scala | 11 +- .../resources/generated/RequiredFields.scala | 213 ++++++++++++++++++ .../test/resources/parsed/RequiredFields.txt | 1 + .../resources/proto/required_fields.proto | 12 + 4 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 scalabuff-compiler/src/test/resources/generated/RequiredFields.scala create mode 100644 scalabuff-compiler/src/test/resources/parsed/RequiredFields.txt create mode 100644 scalabuff-compiler/src/test/resources/proto/required_fields.proto diff --git a/scalabuff-compiler/src/main/net/sandrogrzicic/scalabuff/compiler/Generator.scala b/scalabuff-compiler/src/main/net/sandrogrzicic/scalabuff/compiler/Generator.scala index d14e1c0..aeddee3 100644 --- a/scalabuff-compiler/src/main/net/sandrogrzicic/scalabuff/compiler/Generator.scala +++ b/scalabuff-compiler/src/main/net/sandrogrzicic/scalabuff/compiler/Generator.scala @@ -279,12 +279,20 @@ class Generator protected (sourceName: String, importedSymbols: Map[String, Impo case _ => // "missing combination " } } + out.append("\n") + .append(indent2).append("import scala.collection.JavaConversions.mapAsScalaMap") + out.append("\n") + .append(indent2).append("val required: scala.collection.mutable.Map[String, Boolean] = new java.util.HashMap[String, Boolean]") out.append("\n") .append(indent2).append("def __newMerged = ").append(name).append("(\n") fields.foreach { field => out.append(indent3) if (field.label == REPEATED) out.append("Vector(") - out.append(field.name.toTemporaryIdent) + if (field.label == REQUIRED) { + out.append("if (required(\"").append(field.name).append("\")) ") + out.append(field.name.toTemporaryIdent).append(" else ").append(field.name.toTemporaryIdent) + } + else out.append(field.name.toTemporaryIdent) if (field.label == REPEATED) out.append(": _*)") out.append(",\n") } @@ -300,6 +308,7 @@ class Generator protected (sourceName: String, importedSymbols: Map[String, Impo case _ => false } out.append(indent3).append("case ").append((field.number << 3) | field.fType.wireType).append(" => ") + if (field.label == REQUIRED) out.append("required(\"").append(field.name).append("\") = true; ") out.append(field.name.toTemporaryIdent).append(" ") if (field.label == REPEATED) out.append("+") diff --git a/scalabuff-compiler/src/test/resources/generated/RequiredFields.scala b/scalabuff-compiler/src/test/resources/generated/RequiredFields.scala new file mode 100644 index 0000000..34f9ff7 --- /dev/null +++ b/scalabuff-compiler/src/test/resources/generated/RequiredFields.scala @@ -0,0 +1,213 @@ +// Generated by ScalaBuff, the Scala Protocol Buffers compiler. DO NOT EDIT! +// source: required_fields.proto + +package resources.generated + +final case class Required_v1 ( + `requiredField1`: Int = 0, + `requiredField2`: String = "" +) extends com.google.protobuf.GeneratedMessageLite + with com.google.protobuf.MessageLite.Builder + with net.sandrogrzicic.scalabuff.Message[Required_v1] + with net.sandrogrzicic.scalabuff.Parser[Required_v1] { + + + + def writeTo(output: com.google.protobuf.CodedOutputStream) { + output.writeInt32(1, `requiredField1`) + output.writeString(2, `requiredField2`) + } + + def getSerializedSize = { + import com.google.protobuf.CodedOutputStream._ + var __size = 0 + __size += computeInt32Size(1, `requiredField1`) + __size += computeStringSize(2, `requiredField2`) + + __size + } + + def mergeFrom(in: com.google.protobuf.CodedInputStream, extensionRegistry: com.google.protobuf.ExtensionRegistryLite): Required_v1 = { + import com.google.protobuf.ExtensionRegistryLite.{getEmptyRegistry => _emptyRegistry} + var __requiredField1: Int = 0 + var __requiredField2: String = "" + + import scala.collection.JavaConversions.mapAsScalaMap + val required: scala.collection.mutable.Map[String, Boolean] = new java.util.HashMap[String, Boolean] + def __newMerged = Required_v1( + if (required("required_field_1")) __requiredField1 else __requiredField1, + if (required("required_field_2")) __requiredField2 else __requiredField2 + ) + while (true) in.readTag match { + case 0 => return __newMerged + case 8 => required("required_field_1") = true; __requiredField1 = in.readInt32() + case 18 => required("required_field_2") = true; __requiredField2 = in.readString() + case default => if (!in.skipField(default)) return __newMerged + } + null + } + + def mergeFrom(m: Required_v1) = { + Required_v1( + m.`requiredField1`, + m.`requiredField2` + ) + } + + def getDefaultInstanceForType = Required_v1.defaultInstance + def clear = getDefaultInstanceForType + def isInitialized = true + def build = this + def buildPartial = this + def parsePartialFrom(cis: com.google.protobuf.CodedInputStream, er: com.google.protobuf.ExtensionRegistryLite) = mergeFrom(cis, er) + override def getParserForType = this + def newBuilderForType = getDefaultInstanceForType + def toBuilder = this + def toJson(indent: Int = 0): String = { + val indent0 = "\n" + ("\t" * indent) + val (indent1, indent2) = (indent0 + "\t", indent0 + "\t\t") + val sb = StringBuilder.newBuilder + sb + .append("{") + sb.append(indent1).append("\"requiredField1\": ").append("\"").append(`requiredField1`).append("\"").append(',') + sb.append(indent1).append("\"requiredField2\": ").append("\"").append(`requiredField2`).append("\"").append(',') + if (sb.last.equals(',')) sb.length -= 1 + sb.append(indent0).append("}") + sb.toString() + } + +} + +object Required_v1 { + @scala.beans.BeanProperty val defaultInstance = new Required_v1() + + def parseFrom(data: Array[Byte]): Required_v1 = defaultInstance.mergeFrom(data) + def parseFrom(data: Array[Byte], offset: Int, length: Int): Required_v1 = defaultInstance.mergeFrom(data, offset, length) + def parseFrom(byteString: com.google.protobuf.ByteString): Required_v1 = defaultInstance.mergeFrom(byteString) + def parseFrom(stream: java.io.InputStream): Required_v1 = defaultInstance.mergeFrom(stream) + def parseDelimitedFrom(stream: java.io.InputStream): Option[Required_v1] = defaultInstance.mergeDelimitedFromStream(stream) + + val REQUIRED_FIELD_1_FIELD_NUMBER = 1 + val REQUIRED_FIELD_2_FIELD_NUMBER = 2 + + def newBuilder = defaultInstance.newBuilderForType + def newBuilder(prototype: Required_v1) = defaultInstance.mergeFrom(prototype) + +} +final case class Required_v2 ( + `requiredField1`: Int = 0, + `requiredField2`: String = "", + `requiredField3`: String = "" +) extends com.google.protobuf.GeneratedMessageLite + with com.google.protobuf.MessageLite.Builder + with net.sandrogrzicic.scalabuff.Message[Required_v2] + with net.sandrogrzicic.scalabuff.Parser[Required_v2] { + + + + def writeTo(output: com.google.protobuf.CodedOutputStream) { + output.writeInt32(1, `requiredField1`) + output.writeString(2, `requiredField2`) + output.writeString(3, `requiredField3`) + } + + def getSerializedSize = { + import com.google.protobuf.CodedOutputStream._ + var __size = 0 + __size += computeInt32Size(1, `requiredField1`) + __size += computeStringSize(2, `requiredField2`) + __size += computeStringSize(3, `requiredField3`) + + __size + } + + def mergeFrom(in: com.google.protobuf.CodedInputStream, extensionRegistry: com.google.protobuf.ExtensionRegistryLite): Required_v2 = { + import com.google.protobuf.ExtensionRegistryLite.{getEmptyRegistry => _emptyRegistry} + var __requiredField1: Int = 0 + var __requiredField2: String = "" + var __requiredField3: String = "" + + import scala.collection.JavaConversions.mapAsScalaMap + val required: scala.collection.mutable.Map[String, Boolean] = new java.util.HashMap[String, Boolean] + def __newMerged = Required_v2( + if (required("required_field_1")) __requiredField1 else __requiredField1, + if (required("required_field_2")) __requiredField2 else __requiredField2, + if (required("required_field_3")) __requiredField3 else __requiredField3 + ) + while (true) in.readTag match { + case 0 => return __newMerged + case 8 => required("required_field_1") = true; __requiredField1 = in.readInt32() + case 18 => required("required_field_2") = true; __requiredField2 = in.readString() + case 26 => required("required_field_3") = true; __requiredField3 = in.readString() + case default => if (!in.skipField(default)) return __newMerged + } + null + } + + def mergeFrom(m: Required_v2) = { + Required_v2( + m.`requiredField1`, + m.`requiredField2`, + m.`requiredField3` + ) + } + + def getDefaultInstanceForType = Required_v2.defaultInstance + def clear = getDefaultInstanceForType + def isInitialized = true + def build = this + def buildPartial = this + def parsePartialFrom(cis: com.google.protobuf.CodedInputStream, er: com.google.protobuf.ExtensionRegistryLite) = mergeFrom(cis, er) + override def getParserForType = this + def newBuilderForType = getDefaultInstanceForType + def toBuilder = this + def toJson(indent: Int = 0): String = { + val indent0 = "\n" + ("\t" * indent) + val (indent1, indent2) = (indent0 + "\t", indent0 + "\t\t") + val sb = StringBuilder.newBuilder + sb + .append("{") + sb.append(indent1).append("\"requiredField1\": ").append("\"").append(`requiredField1`).append("\"").append(',') + sb.append(indent1).append("\"requiredField2\": ").append("\"").append(`requiredField2`).append("\"").append(',') + sb.append(indent1).append("\"requiredField3\": ").append("\"").append(`requiredField3`).append("\"").append(',') + if (sb.last.equals(',')) sb.length -= 1 + sb.append(indent0).append("}") + sb.toString() + } + +} + +object Required_v2 { + @scala.beans.BeanProperty val defaultInstance = new Required_v2() + + def parseFrom(data: Array[Byte]): Required_v2 = defaultInstance.mergeFrom(data) + def parseFrom(data: Array[Byte], offset: Int, length: Int): Required_v2 = defaultInstance.mergeFrom(data, offset, length) + def parseFrom(byteString: com.google.protobuf.ByteString): Required_v2 = defaultInstance.mergeFrom(byteString) + def parseFrom(stream: java.io.InputStream): Required_v2 = defaultInstance.mergeFrom(stream) + def parseDelimitedFrom(stream: java.io.InputStream): Option[Required_v2] = defaultInstance.mergeDelimitedFromStream(stream) + + val REQUIRED_FIELD_1_FIELD_NUMBER = 1 + val REQUIRED_FIELD_2_FIELD_NUMBER = 2 + val REQUIRED_FIELD_3_FIELD_NUMBER = 3 + + def newBuilder = defaultInstance.newBuilderForType + def newBuilder(prototype: Required_v2) = defaultInstance.mergeFrom(prototype) + +} + +object RequiredFields { + def registerAllExtensions(registry: com.google.protobuf.ExtensionRegistryLite) { + } + + private val fromBinaryHintMap = collection.immutable.HashMap[String, Array[Byte] ⇒ com.google.protobuf.GeneratedMessageLite]( + "Required_v1" -> (bytes ⇒ Required_v1.parseFrom(bytes)), + "Required_v2" -> (bytes ⇒ Required_v2.parseFrom(bytes)) + ) + + def deserializePayload(payload: Array[Byte], payloadType: String): com.google.protobuf.GeneratedMessageLite = { + fromBinaryHintMap.get(payloadType) match { + case Some(f) ⇒ f(payload) + case None ⇒ throw new IllegalArgumentException(s"unimplemented deserialization of message payload of type [${payloadType}]") + } + } +} diff --git a/scalabuff-compiler/src/test/resources/parsed/RequiredFields.txt b/scalabuff-compiler/src/test/resources/parsed/RequiredFields.txt new file mode 100644 index 0000000..df9b94b --- /dev/null +++ b/scalabuff-compiler/src/test/resources/parsed/RequiredFields.txt @@ -0,0 +1 @@ +List(PackageStatement(resources.generated), Message(Required_v1,MessageBody(List(Field(required,Int32,required_field_1,1,List(),), Field(required,String,required_field_2,2,List(),)),List(),List(),List(),List(),List(),List())), Message(Required_v2,MessageBody(List(Field(required,Int32,required_field_1,1,List(),), Field(required,String,required_field_2,2,List(),), Field(required,String,required_field_3,3,List(),)),List(),List(),List(),List(),List(),List()))) diff --git a/scalabuff-compiler/src/test/resources/proto/required_fields.proto b/scalabuff-compiler/src/test/resources/proto/required_fields.proto new file mode 100644 index 0000000..778ac44 --- /dev/null +++ b/scalabuff-compiler/src/test/resources/proto/required_fields.proto @@ -0,0 +1,12 @@ +package resources.generated; + +message Required_v1 { + required int32 required_field_1 = 1; + required string required_field_2 = 2; +} + +message Required_v2 { + required int32 required_field_1 = 1; + required string required_field_2 = 2; + required string required_field_3 = 3; +} From a4f0748e9866d50b4b5a7dab56ba90ed7f5f79ac Mon Sep 17 00:00:00 2001 From: Jesper Lundgren Date: Mon, 2 Mar 2015 13:36:56 +0800 Subject: [PATCH 2/2] test for deserialization required field check --- .../src/test/tests/MessageTest.scala | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/scalabuff-compiler/src/test/tests/MessageTest.scala b/scalabuff-compiler/src/test/tests/MessageTest.scala index d5fbb0a..f426c70 100644 --- a/scalabuff-compiler/src/test/tests/MessageTest.scala +++ b/scalabuff-compiler/src/test/tests/MessageTest.scala @@ -5,6 +5,7 @@ import resources.generated._ import com.google.protobuf._ import scala.collection._ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import scala.util.{Try, Success, Failure} /** * Tests whether generated Scala classes function correctly. @@ -43,6 +44,25 @@ class MessageTest extends FunSuite with Matchers { received.repeatedBytesField should equal (sent.repeatedBytesField) } + test("Missing required field") { + val v1 = Required_v1(3,"test") + val v2 = Required_v2(2,"test2","test3") + + + val received1 = Required_v1.defaultInstance.mergeFrom(v2.toByteArray) + val received2 = Try { Required_v2.defaultInstance.mergeFrom(v1.toByteArray) } + + val msg = received2 match { + case Success(s) => s.toString + case Failure(f) => f.getMessage + } + + received1.`requiredField1` should equal(2) + received1.`requiredField2` should equal("test2") + received2.isFailure should equal(true) + msg should equal("key not found: required_field_3") + } + test("object.parseFrom") { val message = ComplexMessage(ByteString.copyFromUtf8("Sandro Gržičić"))