From 36c80683ea34d7de690418314776d0896bc9ef5c Mon Sep 17 00:00:00 2001 From: Ba Le Xuan Date: Wed, 3 Apr 2024 19:37:41 +0200 Subject: [PATCH] Support ObjectId in schema-bson --- .../zio/schema/codec/BsonSchemaCodec.scala | 240 +++++++++++------- .../main/scala/zio/schema/codec/package.scala | 22 ++ .../schema/codec/BsonSchemaCodecSpec.scala | 41 +++ 3 files changed, 207 insertions(+), 96 deletions(-) create mode 100644 zio-schema-bson/src/main/scala/zio/schema/codec/package.scala diff --git a/zio-schema-bson/src/main/scala/zio/schema/codec/BsonSchemaCodec.scala b/zio-schema-bson/src/main/scala/zio/schema/codec/BsonSchemaCodec.scala index cc2557f98..4150f700d 100644 --- a/zio-schema-bson/src/main/scala/zio/schema/codec/BsonSchemaCodec.scala +++ b/zio-schema-bson/src/main/scala/zio/schema/codec/BsonSchemaCodec.scala @@ -7,6 +7,7 @@ import scala.collection.compat._ import scala.collection.immutable.{ HashMap, ListMap } import scala.jdk.CollectionConverters._ +import org.bson.types.ObjectId import org.bson.{ BsonDocument, BsonNull, BsonReader, BsonType, BsonValue, BsonWriter } import zio.bson.BsonBuilder._ @@ -504,6 +505,10 @@ object BsonSchemaCodec { directEncoder => override def encode(writer: BsonWriter, value: DynamicValue, ctx: BsonEncoder.EncoderContext): Unit = value match { + case DynamicValue.Record(_, values) if values.headOption.exists(_._1 == bson.ObjectIdTag) => + val id = values.head._2.toTypedValueOption[String].get + writer.writeObjectId(new ObjectId(id)) + case DynamicValue.Record(_, values) => val nextCtx = BsonEncoder.EncoderContext.default @@ -546,6 +551,11 @@ object BsonSchemaCodec { override def toBsonValue(value: DynamicValue): BsonValue = value match { + case DynamicValue.Record(_, values) if values.headOption.exists(_._1 == bson.ObjectIdTag) => + val id = values.head._2.toTypedValueOption[String].get + val objectId = new ObjectId(id) + objectId.toBsonValue + case DynamicValue.Record(_, values) => new BsonDocument(values.view.map { case (key, value) => element(key, directEncoder.toBsonValue(value)) @@ -834,7 +844,15 @@ object BsonSchemaCodec { DynamicValue.Primitive(Chunk.fromArray(bsonValue.asBinary().getData), StandardType.BinaryType) case BsonType.UNDEFINED => DynamicValue.NoneValue case BsonType.OBJECT_ID => - DynamicValue.Primitive(bsonValue.asObjectId().getValue.toHexString, StandardType.StringType) + DynamicValue.Record( + TypeId.Structural, + ListMap( + bson.ObjectIdTag -> DynamicValue.Primitive( + bsonValue.asObjectId().getValue.toHexString, + StandardType.StringType + ) + ) + ) case BsonType.BOOLEAN => DynamicValue.Primitive(bsonValue.asBoolean().getValue, StandardType.BoolType) case BsonType.DATE_TIME => DynamicValue.Primitive(Instant.ofEpochMilli(bsonValue.asDateTime().getValue), StandardType.InstantType) @@ -1145,39 +1163,49 @@ object BsonSchemaCodec { private val len = nonTransientFields.length - def encode(writer: BsonWriter, value: Z, ctx: BsonEncoder.EncoderContext): Unit = { - val nextCtx = ctx.copy(inlineNextObject = false) + def encode(writer: BsonWriter, value: Z, ctx: BsonEncoder.EncoderContext): Unit = + if (names.size == 1 && names(0) == bson.ObjectIdTag) { + val fieldValue = nonTransientFields(0).get(value) + val id = new ObjectId(fieldValue.toString) + writer.writeObjectId(id) + } else { + val nextCtx = ctx.copy(inlineNextObject = false) + + if (!ctx.inlineNextObject) writer.writeStartDocument() - if (!ctx.inlineNextObject) writer.writeStartDocument() + var i = 0 - var i = 0 + while (i < len) { + val tc = tcs(i) + val fieldValue = nonTransientFields(i).get(value) - while (i < len) { - val tc = tcs(i) - val fieldValue = nonTransientFields(i).get(value) + if (keepNulls || !tc.isAbsent(fieldValue)) { + writer.writeName(names(i)) + tc.encode(writer, fieldValue, nextCtx) + } - if (keepNulls || !tc.isAbsent(fieldValue)) { - writer.writeName(names(i)) - tc.encode(writer, fieldValue, nextCtx) + i += 1 } - i += 1 + if (!ctx.inlineNextObject) writer.writeEndDocument() } - if (!ctx.inlineNextObject) writer.writeEndDocument() - } - - def toBsonValue(value: Z): BsonValue = { - val elements = nonTransientFields.indices.view.flatMap { idx => - val fieldValue = nonTransientFields(idx).get(value) - val tc = tcs(idx) + def toBsonValue(value: Z): BsonValue = + if (names.size == 1 && names(0) == bson.ObjectIdTag) { + val fieldValue = nonTransientFields(0).get(value) + val id = new ObjectId(fieldValue.toString) + id.toBsonValue + } else { + val elements = nonTransientFields.indices.view.flatMap { idx => + val fieldValue = nonTransientFields(idx).get(value) + val tc = tcs(idx) - if (keepNulls || !tc.isAbsent(fieldValue)) Some(element(names(idx), tc.toBsonValue(fieldValue))) - else None - }.to(Chunk) + if (keepNulls || !tc.isAbsent(fieldValue)) Some(element(names(idx), tc.toBsonValue(fieldValue))) + else None + }.to(Chunk) - new BsonDocument(elements.asJava) - } + new BsonDocument(elements.asJava) + } } } @@ -1189,7 +1217,6 @@ object BsonSchemaCodec { private[codec] def caseClassDecoder[Z](caseClassSchema: Schema.Record[Z]): BsonDecoder[Z] = { val fields = caseClassSchema.fields val len: Int = fields.length - Array.ofDim[Any](len) val fieldNames = fields.map { f => f.annotations.collectFirst { case bsonField(n) => n }.getOrElse(f.name.asInstanceOf[String]) }.toArray @@ -1209,91 +1236,112 @@ object BsonSchemaCodec { lazy val tcs: Array[BsonDecoder[Any]] = schemas.map(s => schemaDecoder(s).asInstanceOf[BsonDecoder[Any]]) new BsonDecoder[Z] { - def decodeUnsafe(reader: BsonReader, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z = unsafeCall(trace) { - reader.readStartDocument() - - val nextCtx = BsonDecoder.BsonDecoderContext.default - val ps: Array[Any] = Array.ofDim(len) - - while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { - val name = reader.readName() - val idx = indexes.getOrElse(name, -1) - - if (idx >= 0) { - val nextTrace = spans(idx) :: trace - val tc = tcs(idx) - if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate") - ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) { - val opt = BsonDecoder.option(tc).decodeUnsafe(reader, nextTrace, nextCtx) - opt.getOrElse(fields(idx).defaultValue.get) - } else { - tc.decodeUnsafe(reader, nextTrace, nextCtx) + def decodeUnsafe(reader: BsonReader, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z = + if (fieldNames.size == 1 && fieldNames(0) == bson.ObjectIdTag) { + val id = reader.readObjectId.toHexString + Unsafe.unsafe { implicit u => + caseClassSchema.construct(Chunk.fromArray(Array(id))) match { + case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err") + case Right(value) => value + } + } + } else { + unsafeCall(trace) { + reader.readStartDocument() + + val nextCtx = BsonDecoder.BsonDecoderContext.default + val ps: Array[Any] = Array.ofDim(len) + + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + val name = reader.readName() + val idx = indexes.getOrElse(name, -1) + + if (idx >= 0) { + val nextTrace = spans(idx) :: trace + val tc = tcs(idx) + if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate") + ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) { + val opt = BsonDecoder.option(tc).decodeUnsafe(reader, nextTrace, nextCtx) + opt.getOrElse(fields(idx).defaultValue.get) + } else { + tc.decodeUnsafe(reader, nextTrace, nextCtx) + } + } else if (noExtra && !ctx.ignoreExtraField.contains(name)) { + throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.") + } else reader.skipValue() } - } else if (noExtra && !ctx.ignoreExtraField.contains(name)) { - throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.") - } else reader.skipValue() - } - var i = 0 - while (i < len) { - if (ps(i) == null) { - if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) { - ps(i) = fields(i).defaultValue.get - } else { - ps(i) = tcs(i).decodeMissingUnsafe(spans(i) :: trace) + var i = 0 + while (i < len) { + if (ps(i) == null) { + if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) { + ps(i) = fields(i).defaultValue.get + } else { + ps(i) = tcs(i).decodeMissingUnsafe(spans(i) :: trace) + } + } + i += 1 } - } - i += 1 - } - reader.readEndDocument() + reader.readEndDocument() - Unsafe.unsafe { implicit u => - caseClassSchema.construct(Chunk.fromArray(ps)) match { - case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err") - case Right(value) => value + Unsafe.unsafe { implicit u => + caseClassSchema.construct(Chunk.fromArray(ps)) match { + case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err") + case Right(value) => value + } + } } } - } def fromBsonValueUnsafe(value: BsonValue, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): Z = - assumeType(trace)(BsonType.DOCUMENT, value) { value => - val nextCtx = BsonDecoder.BsonDecoderContext.default - val ps: Array[Any] = Array.ofDim(len) - - value.asDocument().asScala.foreachEntry { (name, value) => - val idx = indexes.getOrElse(name, -1) - - if (idx >= 0) { - val nextTrace = spans(idx) :: trace - val tc = tcs(idx) - if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate") - ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) { - val opt = BsonDecoder.option(tc).fromBsonValueUnsafe(value, nextTrace, nextCtx) - opt.getOrElse(fields(idx).defaultValue.get) - } else { - tc.fromBsonValueUnsafe(value, nextTrace, nextCtx) - } - } else if (noExtra && !ctx.ignoreExtraField.contains(name)) - throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.") + if (value.getBsonType == BsonType.OBJECT_ID) { + Unsafe.unsafe { implicit u => + val ps: Array[Any] = Array(value.asObjectId.getValue.toHexString) + caseClassSchema.construct(Chunk.fromArray(ps)) match { + case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err") + case Right(value) => value + } } + } else { + assumeType(trace)(BsonType.DOCUMENT, value) { value => + val nextCtx = BsonDecoder.BsonDecoderContext.default + val ps: Array[Any] = Array.ofDim(len) + + value.asDocument().asScala.foreachEntry { (name, value) => + val idx = indexes.getOrElse(name, -1) + + if (idx >= 0) { + val nextTrace = spans(idx) :: trace + val tc = tcs(idx) + if (ps(idx) != null) throw BsonDecoder.Error(nextTrace, "duplicate") + ps(idx) = if ((fields(idx).optional || fields(idx).transient) && fields(idx).defaultValue.isDefined) { + val opt = BsonDecoder.option(tc).fromBsonValueUnsafe(value, nextTrace, nextCtx) + opt.getOrElse(fields(idx).defaultValue.get) + } else { + tc.fromBsonValueUnsafe(value, nextTrace, nextCtx) + } + } else if (noExtra && !ctx.ignoreExtraField.contains(name)) + throw BsonDecoder.Error(BsonTrace.Field(name) :: trace, "Invalid extra field.") + } - var i = 0 - while (i < len) { - if (ps(i) == null) { - ps(i) = if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) { - fields(i).defaultValue.get - } else { - tcs(i).decodeMissingUnsafe(spans(i) :: trace) + var i = 0 + while (i < len) { + if (ps(i) == null) { + ps(i) = if ((fields(i).optional || fields(i).transient) && fields(i).defaultValue.isDefined) { + fields(i).defaultValue.get + } else { + tcs(i).decodeMissingUnsafe(spans(i) :: trace) + } } + i += 1 } - i += 1 - } - Unsafe.unsafe { implicit u => - caseClassSchema.construct(Chunk.fromArray(ps)) match { - case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err") - case Right(value) => value + Unsafe.unsafe { implicit u => + caseClassSchema.construct(Chunk.fromArray(ps)) match { + case Left(err) => throw BsonDecoder.Error(trace, s"Failed to construct case class: $err") + case Right(value) => value + } } } } diff --git a/zio-schema-bson/src/main/scala/zio/schema/codec/package.scala b/zio-schema-bson/src/main/scala/zio/schema/codec/package.scala new file mode 100644 index 000000000..162189af8 --- /dev/null +++ b/zio-schema-bson/src/main/scala/zio/schema/codec/package.scala @@ -0,0 +1,22 @@ +package zio.schema.codec + +import org.bson.types.ObjectId + +import zio.schema.{ Schema, TypeId } + +package object bson { + val ObjectIdTag = "$oid" + + implicit val ObjectIdSchema: Schema[ObjectId] = + Schema.CaseClass1[String, ObjectId]( + id0 = TypeId.fromTypeName("ObjectId"), + field0 = Schema.Field( + name0 = ObjectIdTag, + schema0 = Schema[String], + get0 = _.toHexString, + set0 = (_, idStr) => new ObjectId(idStr) + ), + defaultConstruct0 = new ObjectId(_) + ) + +} diff --git a/zio-schema-bson/src/test/scala/zio/schema/codec/BsonSchemaCodecSpec.scala b/zio-schema-bson/src/test/scala/zio/schema/codec/BsonSchemaCodecSpec.scala index 106f8bbec..bd1392dbf 100644 --- a/zio-schema-bson/src/test/scala/zio/schema/codec/BsonSchemaCodecSpec.scala +++ b/zio-schema-bson/src/test/scala/zio/schema/codec/BsonSchemaCodecSpec.scala @@ -7,6 +7,7 @@ import org.bson.codecs.configuration.CodecRegistry import org.bson.codecs.{ Codec => BCodec, DecoderContext, EncoderContext } import org.bson.conversions.Bson import org.bson.io.BasicOutputBuffer +import org.bson.types.ObjectId import zio.bson.BsonBuilder._ import zio.bson._ @@ -50,6 +51,34 @@ object BsonSchemaCodecSpec extends ZIOSpecDefault { implicit lazy val codec: BsonCodec[EnumLike] = BsonSchemaCodec.bsonCodec(schema) } + case class CustomerId(value: ObjectId) extends AnyVal + case class Customer(id: CustomerId, name: String, age: Int, invitedFriends: List[CustomerId]) + + object Customer { + implicit lazy val customerIdSchema: Schema[CustomerId] = bson.ObjectIdSchema.transform(CustomerId(_), _.value) + + implicit lazy val customerSchema: Schema[Customer] = DeriveSchema.gen[Customer] + implicit lazy val customerCodec: BsonCodec[Customer] = BsonSchemaCodec.bsonCodec(customerSchema) + + val example: Customer = Customer( + id = CustomerId(ObjectId.get), + name = "Joseph", + age = 18, + invitedFriends = List(CustomerId(ObjectId.get), CustomerId(ObjectId.get)) + ) + + lazy val genCustomerId: Gen[Any, CustomerId] = + Gen.vectorOfN(12)(Gen.byte).map(bs => new ObjectId(bs.toArray)).map(CustomerId.apply) + + def gen: Gen[Sized, Customer] = + for { + id <- genCustomerId + name <- Gen.string + age <- Gen.int + friends <- Gen.listOf(genCustomerId) + } yield Customer(id, name, age, friends) + } + def spec: Spec[TestEnvironment with Scope, Any] = suite("BsonSchemaCodecSpec")( suite("round trip")( roundTripTest("SimpleClass")( @@ -66,6 +95,18 @@ object BsonSchemaCodecSpec extends ZIOSpecDefault { Gen.fromIterable(Chunk(EnumLike.A, EnumLike.B)), EnumLike.A, str("A") + ), + roundTripTest("Customer")( + Customer.gen, + Customer.example, + doc( + "id" -> Customer.example.id.value.toBsonValue, + "name" -> str(Customer.example.name), + "age" -> int(Customer.example.age), + "invitedFriends" -> array( + Customer.example.invitedFriends.map(_.value.toBsonValue): _* + ) + ) ) ), suite("configuration")(