From a6ff6a915754ed12ca700822ad890189a3a466ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pa=C5=82ka?= Date: Tue, 27 Sep 2022 18:13:38 +0200 Subject: [PATCH] Support conditional columns with `when` --- src/main/ColumnOp.scala | 6 +- src/main/DataFrame.scala | 8 +- src/main/DataFrameBuilders.scala | 7 +- src/main/UntypedOps.scala | 8 +- src/main/When.scala | 14 +++ src/main/api/api.scala | 2 +- src/main/functions/lit.scala | 2 +- src/main/functions/when.scala | 4 + src/main/types/Coerce.scala | 17 +++ src/main/types/DataType.scala | 180 +----------------------------- src/main/types/Encoder.scala | 181 +++++++++++++++++++++++++++++++ src/test/CoerceTest.scala | 26 +++++ src/test/WhenTest.scala | 59 ++++++++++ 13 files changed, 318 insertions(+), 196 deletions(-) create mode 100644 src/main/When.scala create mode 100644 src/main/functions/when.scala create mode 100644 src/main/types/Coerce.scala create mode 100644 src/main/types/Encoder.scala create mode 100644 src/test/CoerceTest.scala create mode 100644 src/test/WhenTest.scala diff --git a/src/main/ColumnOp.scala b/src/main/ColumnOp.scala index ce5a3ed..7aca58e 100644 --- a/src/main/ColumnOp.scala +++ b/src/main/ColumnOp.scala @@ -16,7 +16,7 @@ object ColumnOp: given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Plus[T1, T2] with type Out = DataType.CommonNumericNonNullableType[T1, T2] given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Plus[T1, T2] with - type Out = DataType.CommonNumericOptType[T1, T2] + type Out = DataType.CommonNumericNullableType[T1, T2] trait Minus[T1 <: DataType, T2 <: DataType]: type Out <: DataType @@ -25,7 +25,7 @@ object ColumnOp: given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Minus[T1, T2] with type Out = DataType.CommonNumericNonNullableType[T1, T2] given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Minus[T1, T2] with - type Out = DataType.CommonNumericOptType[T1, T2] + type Out = 
DataType.CommonNumericNullableType[T1, T2] trait Mult[T1 <: DataType, T2 <: DataType]: type Out <: DataType @@ -34,7 +34,7 @@ object ColumnOp: given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Mult[T1, T2] with type Out = DataType.CommonNumericNonNullableType[T1, T2] given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Mult[T1, T2] with - type Out = DataType.CommonNumericOptType[T1, T2] + type Out = DataType.CommonNumericNullableType[T1, T2] trait Div[T1 <: DataType, T2 <: DataType]: type Out <: DataType diff --git a/src/main/DataFrame.scala b/src/main/DataFrame.scala index 3b28b61..25055ff 100644 --- a/src/main/DataFrame.scala +++ b/src/main/DataFrame.scala @@ -3,7 +3,7 @@ package org.virtuslab.iskra import org.apache.spark.sql import org.apache.spark.sql.SparkSession import scala.quoted.* -import types.{DataType, StructType} +import types.{Encoder, StructEncoder} class DataFrame[Schema](val untyped: UntypedDataFrame): type Alias @@ -47,9 +47,9 @@ object DataFrame: // TODO: Use only a subset of columns private def collectAsImpl[FrameSchema : Type, A : Type](df: Expr[DataFrame[FrameSchema]])(using Quotes): Expr[List[A]] = - Expr.summon[DataType.Encoder[A]] match + Expr.summon[Encoder[A]] match case Some(encoder) => encoder match - case '{ $enc: DataType.StructEncoder[A] { type StructSchema = structSchema } } => + case '{ $enc: StructEncoder[A] { type StructSchema = structSchema } } => Type.of[MacroHelpers.AsTuple[FrameSchema]] match case '[`structSchema`] => '{ ${ df }.untyped.collect.toList.map(row => ${ enc }.decode(row).asInstanceOf[A]) } @@ -58,7 +58,7 @@ object DataFrame: val structColumns = allColumns(Type.of[structSchema]) val errorMsg = s"A data frame with columns:\n${showColumns(frameColumns)}\ncannot be collected as a list of ${Type.show[A]}, which would be encoded as a row with columns:\n${showColumns(structColumns)}" quotes.reflect.report.errorAndAbort(errorMsg) - case '{ $enc: DataType.Encoder[A] { type ColumnType = colType 
} } => + case '{ $enc: Encoder[A] { type ColumnType = colType } } => def fromDataType[T : Type] = Type.of[T] match case '[`colType`] => diff --git a/src/main/DataFrameBuilders.scala b/src/main/DataFrameBuilders.scala index ee92fbc..cddd78b 100644 --- a/src/main/DataFrameBuilders.scala +++ b/src/main/DataFrameBuilders.scala @@ -4,8 +4,7 @@ import scala.quoted._ import org.apache.spark.sql import org.apache.spark.sql.SparkSession import org.virtuslab.iskra.DataFrame -import org.virtuslab.iskra.types.{DataType, StructType} -import DataType.{Encoder, StructEncoder, PrimitiveEncoder} +import org.virtuslab.iskra.types.{DataType, StructType, Encoder, StructEncoder, PrimitiveEncoder} object DataFrameBuilders: extension [A](seq: Seq[A])(using encoder: Encoder[A]) @@ -13,11 +12,11 @@ object DataFrameBuilders: private def toTypedDFImpl[A : Type](seq: Expr[Seq[A]], encoder: Expr[Encoder[A]], spark: Expr[SparkSession])(using Quotes) = val (schemaType, schema, encodeFun) = encoder match - case '{ $e: DataType.StructEncoder.Aux[A, t] } => + case '{ $e: StructEncoder.Aux[A, t] } => val schema = '{ ${ e }.catalystType } val encodeFun: Expr[A => sql.Row] = '{ ${ e }.encode } (Type.of[t], schema, encodeFun) - case '{ $e: DataType.Encoder.Aux[tpe, t] } => + case '{ $e: Encoder.Aux[tpe, t] } => val schema = '{ sql.types.StructType(Seq( sql.types.StructField("value", ${ encoder }.catalystType, ${ encoder }.isNullable ) diff --git a/src/main/UntypedOps.scala b/src/main/UntypedOps.scala index e5a03fb..4541d0c 100644 --- a/src/main/UntypedOps.scala +++ b/src/main/UntypedOps.scala @@ -1,16 +1,16 @@ package org.virtuslab.iskra import scala.quoted.* -import types.{DataType, StructType} +import types.{DataType, Encoder, StructType, StructEncoder} object UntypedOps: extension (untyped: UntypedColumn) def typed[A <: DataType] = Column[A](untyped) extension (df: UntypedDataFrame) - transparent inline def typed[A](using encoder: DataType.StructEncoder[A]): DataFrame[?] 
= ${ typedDataFrameImpl('df, 'encoder) } // TODO: Check schema at runtime? Check if names of columns match? + transparent inline def typed[A](using encoder: StructEncoder[A]): DataFrame[?] = ${ typedDataFrameImpl('df, 'encoder) } // TODO: Check schema at runtime? Check if names of columns match? - private def typedDataFrameImpl[A : Type](df: Expr[UntypedDataFrame], encoder: Expr[DataType.StructEncoder[A]])(using Quotes) = + private def typedDataFrameImpl[A : Type](df: Expr[UntypedDataFrame], encoder: Expr[StructEncoder[A]])(using Quotes) = encoder match - case '{ ${e}: DataType.Encoder.Aux[tpe, StructType[t]] } => + case '{ ${e}: Encoder.Aux[tpe, StructType[t]] } => '{ DataFrame[t](${ df }) } diff --git a/src/main/When.scala b/src/main/When.scala new file mode 100644 index 0000000..b8c3b0f --- /dev/null +++ b/src/main/When.scala @@ -0,0 +1,14 @@ +package org.virtuslab.iskra + +import org.apache.spark.sql.{functions => f, Column => UntypedColumn} +import org.virtuslab.iskra.types.{Coerce, DataType, BooleanOptType} + +object When: + class WhenColumn[T <: DataType](untyped: UntypedColumn) extends Column[DataType.Nullable[T]](untyped): + def when[U <: DataType](condition: Column[BooleanOptType], value: Column[U])(using coerce: Coerce[T, U]): WhenColumn[coerce.Coerced] = + WhenColumn(this.untyped.when(condition.untyped, value.untyped)) + def otherwise[U <: DataType](value: Column[U])(using coerce: Coerce[T, U]): Column[coerce.Coerced] = + Column(this.untyped.otherwise(value.untyped)) + + def when[T <: DataType](condition: Column[BooleanOptType], value: Column[T]): WhenColumn[T] = + WhenColumn(f.when(condition.untyped, value.untyped)) diff --git a/src/main/api/api.scala b/src/main/api/api.scala index d509ff9..13f0be3 100644 --- a/src/main/api/api.scala +++ b/src/main/api/api.scala @@ -28,7 +28,7 @@ export org.virtuslab.iskra.$ export org.virtuslab.iskra.{Column, DataFrame, UntypedColumn, UntypedDataFrame, :=, /} object functions: - export 
org.virtuslab.iskra.functions.lit + export org.virtuslab.iskra.functions.{lit, when} export org.virtuslab.iskra.functions.Aggregates.* export org.apache.spark.sql.SparkSession diff --git a/src/main/functions/lit.scala b/src/main/functions/lit.scala index ffe5c7f..737bc49 100644 --- a/src/main/functions/lit.scala +++ b/src/main/functions/lit.scala @@ -2,6 +2,6 @@ package org.virtuslab.iskra.functions import org.apache.spark.sql import org.virtuslab.iskra.Column -import org.virtuslab.iskra.types.DataType.PrimitiveEncoder +import org.virtuslab.iskra.types.PrimitiveEncoder def lit[A](value: A)(using encoder: PrimitiveEncoder[A]): Column[encoder.ColumnType] = Column(sql.functions.lit(encoder.encode(value))) diff --git a/src/main/functions/when.scala b/src/main/functions/when.scala new file mode 100644 index 0000000..d942418 --- /dev/null +++ b/src/main/functions/when.scala @@ -0,0 +1,4 @@ +package org.virtuslab.iskra +package functions + +export When.when diff --git a/src/main/types/Coerce.scala b/src/main/types/Coerce.scala new file mode 100644 index 0000000..5c61468 --- /dev/null +++ b/src/main/types/Coerce.scala @@ -0,0 +1,17 @@ +package org.virtuslab.iskra +package types + +import DataType.{CommonNumericNonNullableType, CommonNumericNullableType, NumericOptType, NumericType} + +trait Coerce[-A <: DataType, -B <: DataType]: + type Coerced <: DataType + +object Coerce: + given sameType[A <: DataType]: Coerce[A, A] with + override type Coerced = A + + given nullable[A <: NumericOptType, B <: NumericOptType]: Coerce[A, B] with + override type Coerced = CommonNumericNullableType[A, B] + + given nonNullable[A <: NumericType, B <: NumericType]: Coerce[A, B] with + override type Coerced = CommonNumericNonNullableType[A, B] diff --git a/src/main/types/DataType.scala b/src/main/types/DataType.scala index 7e5c0e0..1593d7d 100644 --- a/src/main/types/DataType.scala +++ b/src/main/types/DataType.scala @@ -1,11 +1,6 @@ package org.virtuslab.iskra package types -import 
scala.quoted._ -import scala.deriving.Mirror -import org.apache.spark.sql -import MacroHelpers.TupleSubtype - sealed trait DataType object DataType: @@ -38,7 +33,7 @@ object DataType: case DoubleOptType => DoubleOptType case StructOptType[schema] => StructOptType[schema] - type CommonNumericOptType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match + type CommonNumericNullableType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match case (DoubleOptType, _) | (_, DoubleOptType) => DoubleOptType case (FloatOptType, _) | (_, FloatOptType) => FloatOptType case (LongOptType, _) | (_, LongOptType) => LongOptType @@ -54,179 +49,6 @@ object DataType: case (ShortOptType, _) | (_, ShortOptType) => ShortType case (ByteOptType, _) | (_, ByteOptType) => ByteType - trait Encoder[-A]: - type ColumnType <: DataType - def encode(value: A): Any - def decode(value: Any): Any - def catalystType: sql.types.DataType - def isNullable: Boolean - - trait PrimitiveEncoder[-A] extends Encoder[A] - - trait PrimitiveNullableEncoder[-A] extends PrimitiveEncoder[Option[A]]: - def encode(value: Option[A]) = value.orNull - def decode(value: Any) = Option(value) - def isNullable = true - - trait PrimitiveNonNullableEncoder[-A] extends PrimitiveEncoder[A]: - def encode(value: A) = value - def decode(value: Any) = value - def isNullable = true - - - object Encoder: - type Aux[-A, E <: DataType] = Encoder[A] { type ColumnType = E } - - inline given boolean: PrimitiveNonNullableEncoder[Boolean] with - type ColumnType = BooleanType - def catalystType = sql.types.BooleanType - inline given booleanOpt: PrimitiveNullableEncoder[Boolean] with - type ColumnType = BooleanOptType - def catalystType = sql.types.BooleanType - - inline given string: PrimitiveNonNullableEncoder[String] with - type ColumnType = StringType - def catalystType = sql.types.StringType - inline given stringOpt: PrimitiveNullableEncoder[String] with - type ColumnType = StringOptType - def catalystType = 
sql.types.StringType - - inline given byte: PrimitiveNonNullableEncoder[Byte] with - type ColumnType = ByteType - def catalystType = sql.types.ByteType - inline given byteOpt: PrimitiveNullableEncoder[Byte] with - type ColumnType = ByteOptType - def catalystType = sql.types.ByteType - - inline given short: PrimitiveNonNullableEncoder[Short] with - type ColumnType = ShortType - def catalystType = sql.types.ShortType - inline given shortOpt: PrimitiveNullableEncoder[Short] with - type ColumnType = ShortOptType - def catalystType = sql.types.ShortType - - inline given int: PrimitiveNonNullableEncoder[Int] with - type ColumnType = IntegerType - def catalystType = sql.types.IntegerType - inline given intOpt: PrimitiveNullableEncoder[Int] with - type ColumnType = IntegerOptType - def catalystType = sql.types.IntegerType - - inline given long: PrimitiveNonNullableEncoder[Long] with - type ColumnType = LongType - def catalystType = sql.types.LongType - inline given longOpt: PrimitiveNullableEncoder[Long] with - type ColumnType = LongOptType - def catalystType = sql.types.LongType - - inline given float: PrimitiveNonNullableEncoder[Float] with - type ColumnType = FloatType - def catalystType = sql.types.FloatType - inline given floatOpt: PrimitiveNullableEncoder[Float] with - type ColumnType = FloatOptType - def catalystType = sql.types.FloatType - - inline given double: PrimitiveNonNullableEncoder[Double] with - type ColumnType = DoubleType - def catalystType = sql.types.DoubleType - inline given doubleOpt: PrimitiveNullableEncoder[Double] with - type ColumnType = DoubleOptType - def catalystType = sql.types.DoubleType - - export StructEncoder.{fromMirror, optFromMirror} - - trait StructEncoder[-A] extends Encoder[A]: - type StructSchema <: Tuple - type ColumnType = StructType[StructSchema] - override def catalystType: sql.types.StructType - override def encode(a: A): sql.Row - - object StructEncoder: - type Aux[-A, E <: Tuple] = StructEncoder[A] { type StructSchema = E } 
- - private case class ColumnInfo( - labelType: Type[?], - labelValue: String, - decodedType: Type[?], - encoder: Expr[Encoder[?]] - ) - - private def getColumnSchemaType(using quotes: Quotes)(subcolumnInfos: List[ColumnInfo]): Type[?] = - subcolumnInfos match - case Nil => Type.of[EmptyTuple] - case info :: tail => - info.labelType match - case '[Name.Subtype[lt]] => - info.encoder match - case '{ ${encoder}: Encoder.Aux[tpe, DataType.Subtype[e]] } => - getColumnSchemaType(tail) match - case '[TupleSubtype[tailType]] => - Type.of[(lt := e) *: tailType] - - private def getSubcolumnInfos(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[ColumnInfo] = - import quotes.reflect.Select - elemLabels match - case '[EmptyTuple] => Nil - case '[label *: labels] => - val labelValue = Type.valueOfConstant[label].get.toString - elemTypes match - case '[tpe *: types] => - Expr.summon[Encoder[tpe]] match - case Some(encoderExpr) => - ColumnInfo(Type.of[label], labelValue, Type.of[tpe], encoderExpr) :: getSubcolumnInfos(Type.of[labels], Type.of[types]) - case _ => quotes.reflect.report.errorAndAbort(s"Could not summon encoder for ${Type.show[tpe]}") - - transparent inline given fromMirror[A]: StructEncoder[A] = ${ fromMirrorImpl[A] } - - def fromMirrorImpl[A : Type](using q: Quotes): Expr[StructEncoder[A]] = - Expr.summon[Mirror.Of[A]].getOrElse(throw new Exception(s"Could not find Mirror when generating encoder for ${Type.show[A]}")) match - case '{ ${mirror}: Mirror.ProductOf[A] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes } } => - val subcolumnInfos = getSubcolumnInfos(Type.of[elementLabels], Type.of[elementTypes]) - val columnSchemaType = getColumnSchemaType(subcolumnInfos) - - val structFieldExprs = subcolumnInfos.map { info => - '{ sql.types.StructField(${Expr(info.labelValue)}, ${info.encoder}.catalystType, ${info.encoder}.isNullable) } - } - val structFields = Expr.ofSeq(structFieldExprs) - - def rowElements(value: Expr[A]) = 
- subcolumnInfos.map { info => - import quotes.reflect.* - info.decodedType match - case '[t] => - '{ ${info.encoder}.asInstanceOf[Encoder[t]].encode(${ Select.unique(value.asTerm, info.labelValue).asExprOf[t] }) } - } - - def rowElementsTuple(row: Expr[sql.Row]): Expr[Tuple] = - val elements = subcolumnInfos.zipWithIndex.map { (info, idx) => - given Quotes = q - '{ ${info.encoder}.decode(${row}.get(${Expr(idx)})) } - } - Expr.ofTupleFromSeq(elements) - - columnSchemaType match - case '[TupleSubtype[t]] => - '{ - (new StructEncoder[A] { - override type StructSchema = t - override def catalystType = sql.types.StructType(${ structFields }) - override def isNullable = false - override def encode(a: A) = - sql.Row.fromSeq(${ Expr.ofSeq(rowElements('a)) }) - override def decode(a: Any) = - ${mirror}.fromProduct(${ rowElementsTuple('{a.asInstanceOf[sql.Row]}) }) - }): StructEncoder[A] { type StructSchema = t } - } - end fromMirrorImpl - - inline given optFromMirror[A](using encoder: StructEncoder[A]): (Encoder[Option[A]] { type ColumnType = StructOptType[encoder.StructSchema] }) = - new Encoder[Option[A]]: - override type ColumnType = StructOptType[encoder.StructSchema] - override def encode(value: Option[A]): Any = value.map(encoder.encode).orNull - override def decode(value: Any): Any = Option(encoder.decode) - override def catalystType = encoder.catalystType - override def isNullable = true - import DataType.NotNull sealed class BooleanOptType extends DataType diff --git a/src/main/types/Encoder.scala b/src/main/types/Encoder.scala new file mode 100644 index 0000000..6ef6818 --- /dev/null +++ b/src/main/types/Encoder.scala @@ -0,0 +1,181 @@ +package org.virtuslab.iskra +package types + +import scala.quoted._ +import scala.deriving.Mirror +import org.apache.spark.sql +import MacroHelpers.TupleSubtype + + +trait Encoder[-A]: + type ColumnType <: DataType + def encode(value: A): Any + def decode(value: Any): Any + def catalystType: sql.types.DataType + def isNullable: 
Boolean + +trait PrimitiveEncoder[-A] extends Encoder[A] + +trait PrimitiveNullableEncoder[-A] extends PrimitiveEncoder[Option[A]]: + def encode(value: Option[A]) = value.orNull + def decode(value: Any) = Option(value) + def isNullable = true + +trait PrimitiveNonNullableEncoder[-A] extends PrimitiveEncoder[A]: + def encode(value: A) = value + def decode(value: Any) = value + def isNullable = false + + +object Encoder: + type Aux[-A, E <: DataType] = Encoder[A] { type ColumnType = E } + + inline given boolean: PrimitiveNonNullableEncoder[Boolean] with + type ColumnType = BooleanType + def catalystType = sql.types.BooleanType + inline given booleanOpt: PrimitiveNullableEncoder[Boolean] with + type ColumnType = BooleanOptType + def catalystType = sql.types.BooleanType + + inline given string: PrimitiveNonNullableEncoder[String] with + type ColumnType = StringType + def catalystType = sql.types.StringType + inline given stringOpt: PrimitiveNullableEncoder[String] with + type ColumnType = StringOptType + def catalystType = sql.types.StringType + + inline given byte: PrimitiveNonNullableEncoder[Byte] with + type ColumnType = ByteType + def catalystType = sql.types.ByteType + inline given byteOpt: PrimitiveNullableEncoder[Byte] with + type ColumnType = ByteOptType + def catalystType = sql.types.ByteType + + inline given short: PrimitiveNonNullableEncoder[Short] with + type ColumnType = ShortType + def catalystType = sql.types.ShortType + inline given shortOpt: PrimitiveNullableEncoder[Short] with + type ColumnType = ShortOptType + def catalystType = sql.types.ShortType + + inline given int: PrimitiveNonNullableEncoder[Int] with + type ColumnType = IntegerType + def catalystType = sql.types.IntegerType + inline given intOpt: PrimitiveNullableEncoder[Int] with + type ColumnType = IntegerOptType + def catalystType = sql.types.IntegerType + + inline given long: PrimitiveNonNullableEncoder[Long] with + type ColumnType = LongType + def catalystType = sql.types.LongType + inline 
given longOpt: PrimitiveNullableEncoder[Long] with + type ColumnType = LongOptType + def catalystType = sql.types.LongType + + inline given float: PrimitiveNonNullableEncoder[Float] with + type ColumnType = FloatType + def catalystType = sql.types.FloatType + inline given floatOpt: PrimitiveNullableEncoder[Float] with + type ColumnType = FloatOptType + def catalystType = sql.types.FloatType + + inline given double: PrimitiveNonNullableEncoder[Double] with + type ColumnType = DoubleType + def catalystType = sql.types.DoubleType + inline given doubleOpt: PrimitiveNullableEncoder[Double] with + type ColumnType = DoubleOptType + def catalystType = sql.types.DoubleType + + export StructEncoder.{fromMirror, optFromMirror} + +trait StructEncoder[-A] extends Encoder[A]: + type StructSchema <: Tuple + type ColumnType = StructType[StructSchema] + override def catalystType: sql.types.StructType + override def encode(a: A): sql.Row + +object StructEncoder: + type Aux[-A, E <: Tuple] = StructEncoder[A] { type StructSchema = E } + + private case class ColumnInfo( + labelType: Type[?], + labelValue: String, + decodedType: Type[?], + encoder: Expr[Encoder[?]] + ) + + private def getColumnSchemaType(using quotes: Quotes)(subcolumnInfos: List[ColumnInfo]): Type[?] 
= + subcolumnInfos match + case Nil => Type.of[EmptyTuple] + case info :: tail => + info.labelType match + case '[Name.Subtype[lt]] => + info.encoder match + case '{ ${encoder}: Encoder.Aux[tpe, DataType.Subtype[e]] } => + getColumnSchemaType(tail) match + case '[TupleSubtype[tailType]] => + Type.of[(lt := e) *: tailType] + + private def getSubcolumnInfos(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[ColumnInfo] = + import quotes.reflect.Select + elemLabels match + case '[EmptyTuple] => Nil + case '[label *: labels] => + val labelValue = Type.valueOfConstant[label].getOrElse(quotes.reflect.report.errorAndAbort(s"Expected a constant label type but got ${Type.show[label]}")).toString + elemTypes match + case '[tpe *: types] => + Expr.summon[Encoder[tpe]] match + case Some(encoderExpr) => + ColumnInfo(Type.of[label], labelValue, Type.of[tpe], encoderExpr) :: getSubcolumnInfos(Type.of[labels], Type.of[types]) + case _ => quotes.reflect.report.errorAndAbort(s"Could not summon encoder for ${Type.show[tpe]}") + + transparent inline given fromMirror[A]: StructEncoder[A] = ${ fromMirrorImpl[A] } + + def fromMirrorImpl[A : Type](using q: Quotes): Expr[StructEncoder[A]] = + Expr.summon[Mirror.Of[A]].getOrElse(quotes.reflect.report.errorAndAbort(s"Could not find Mirror when generating encoder for ${Type.show[A]}")) match + case '{ ${mirror}: Mirror.ProductOf[A] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes } } => + val subcolumnInfos = getSubcolumnInfos(Type.of[elementLabels], Type.of[elementTypes]) + val columnSchemaType = getColumnSchemaType(subcolumnInfos) + + val structFieldExprs = subcolumnInfos.map { info => + '{ sql.types.StructField(${Expr(info.labelValue)}, ${info.encoder}.catalystType, ${info.encoder}.isNullable) } + } + val structFields = Expr.ofSeq(structFieldExprs) + + def rowElements(value: Expr[A]) = + subcolumnInfos.map { info => + import quotes.reflect.* + info.decodedType match + case '[t] => + '{ ${info.encoder}.asInstanceOf[Encoder[t]].encode(${ Select.unique(value.asTerm, info.labelValue).asExprOf[t] }) } + } + + def 
rowElementsTuple(row: Expr[sql.Row]): Expr[Tuple] = + val elements = subcolumnInfos.zipWithIndex.map { (info, idx) => + given Quotes = q + '{ ${info.encoder}.decode(${row}.get(${Expr(idx)})) } + } + Expr.ofTupleFromSeq(elements) + + columnSchemaType match + case '[TupleSubtype[t]] => + '{ + (new StructEncoder[A] { + override type StructSchema = t + override def catalystType = sql.types.StructType(${ structFields }) + override def isNullable = false + override def encode(a: A) = + sql.Row.fromSeq(${ Expr.ofSeq(rowElements('a)) }) + override def decode(a: Any) = + ${mirror}.fromProduct(${ rowElementsTuple('{a.asInstanceOf[sql.Row]}) }) + }): StructEncoder[A] { type StructSchema = t } + } + end fromMirrorImpl + + inline given optFromMirror[A](using encoder: StructEncoder[A]): (Encoder[Option[A]] { type ColumnType = StructOptType[encoder.StructSchema] }) = + new Encoder[Option[A]]: + override type ColumnType = StructOptType[encoder.StructSchema] + override def encode(value: Option[A]): Any = value.map(encoder.encode).orNull + override def decode(value: Any): Any = Option(value).map(encoder.decode) + override def catalystType = encoder.catalystType + override def isNullable = true diff --git a/src/test/CoerceTest.scala b/src/test/CoerceTest.scala new file mode 100644 index 0000000..cf3ca04 --- /dev/null +++ b/src/test/CoerceTest.scala @@ -0,0 +1,26 @@ +package org.virtuslab.iskra +package test + +import org.scalatest.funsuite.AnyFunSuite +import types.* + +class CoerceTest extends AnyFunSuite: + test("coerce-int-double") { + val c = summon[Coerce[IntegerType, DoubleType]] + summon[c.Coerced =:= DoubleType] + } + + test("coerce-short-short-opt") { + val c = summon[Coerce[ShortType, ShortOptType]] + summon[c.Coerced =:= ShortOptType] + } + + test("coerce-long-byte-opt") { + val c = summon[Coerce[LongType, ByteOptType]] + summon[c.Coerced =:= LongOptType] + } + + test("coerce-string-string-opt") { + val c = summon[Coerce[StringType, StringOptType]] + summon[c.Coerced =:= 
StringOptType] + } diff --git a/src/test/WhenTest.scala b/src/test/WhenTest.scala new file mode 100644 index 0000000..3f59969 --- /dev/null +++ b/src/test/WhenTest.scala @@ -0,0 +1,59 @@ +package org.virtuslab.iskra.test + +class WhenTest extends SparkUnitTest: + import org.virtuslab.iskra.api.* + import functions.{lit, when} + import Column.=== // by default shadowed by === from scalatest + + case class Foo(int: Int) + + val foos = Seq( + Foo(1), + Foo(2), + Foo(3) + ).toTypedDF + + test("when-without-fallback") { + val result = foos + .select(when($.int === lit(1), lit("a")).as("strOpt")) + .collectAs[Option[String]] + + result shouldEqual Seq(Some("a"), None, None) + } + + test("when-with-fallback") { + val result = foos + .select{ + when($.int === lit(1), lit(10)) + .otherwise(lit(100d)) + .as("double") + } + .collectAs[Double] + + result shouldEqual Seq(10d, 100d, 100d) + } + + test("when-else-when-without-fallback") { + val result = foos + .select{ + when($.int === lit(1), lit(10)) + .when($.int === lit(2), lit(100L)) + .as("longOpt") + } + .collectAs[Option[Long]] + + result shouldEqual Seq(Some(10L), Some(100L), None) + } + + test("when-else-when-with-fallback") { + val result = foos + .select{ + when($.int === lit(1), lit(10)) + .when($.int === lit(2), lit(100L)) + .otherwise(lit(1000d)) + .as("str") + } + .collectAs[Option[Double]] + + result shouldEqual Seq(Some(10d), Some(100d), Some(1000d)) + }