Skip to content

Commit

Permalink
Support conditional columns with when
Browse files Browse the repository at this point in the history
  • Loading branch information
prolativ committed Sep 27, 2022
1 parent 75fbcee commit a6ff6a9
Show file tree
Hide file tree
Showing 13 changed files with 318 additions and 196 deletions.
6 changes: 3 additions & 3 deletions src/main/ColumnOp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ object ColumnOp:
given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Plus[T1, T2] with
type Out = DataType.CommonNumericNonNullableType[T1, T2]
given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Plus[T1, T2] with
type Out = DataType.CommonNumericOptType[T1, T2]
type Out = DataType.CommonNumericNullableType[T1, T2]

trait Minus[T1 <: DataType, T2 <: DataType]:
type Out <: DataType
Expand All @@ -25,7 +25,7 @@ object ColumnOp:
given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Minus[T1, T2] with
type Out = DataType.CommonNumericNonNullableType[T1, T2]
given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Minus[T1, T2] with
type Out = DataType.CommonNumericOptType[T1, T2]
type Out = DataType.CommonNumericNullableType[T1, T2]

trait Mult[T1 <: DataType, T2 <: DataType]:
type Out <: DataType
Expand All @@ -34,7 +34,7 @@ object ColumnOp:
given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Mult[T1, T2] with
type Out = DataType.CommonNumericNonNullableType[T1, T2]
given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Mult[T1, T2] with
type Out = DataType.CommonNumericOptType[T1, T2]
type Out = DataType.CommonNumericNullableType[T1, T2]

trait Div[T1 <: DataType, T2 <: DataType]:
type Out <: DataType
Expand Down
8 changes: 4 additions & 4 deletions src/main/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.virtuslab.iskra
import org.apache.spark.sql
import org.apache.spark.sql.SparkSession
import scala.quoted.*
import types.{DataType, StructType}
import types.{Encoder, StructEncoder}

class DataFrame[Schema](val untyped: UntypedDataFrame):
type Alias
Expand Down Expand Up @@ -47,9 +47,9 @@ object DataFrame:

// TODO: Use only a subset of columns
private def collectAsImpl[FrameSchema : Type, A : Type](df: Expr[DataFrame[FrameSchema]])(using Quotes): Expr[List[A]] =
Expr.summon[DataType.Encoder[A]] match
Expr.summon[Encoder[A]] match
case Some(encoder) => encoder match
case '{ $enc: DataType.StructEncoder[A] { type StructSchema = structSchema } } =>
case '{ $enc: StructEncoder[A] { type StructSchema = structSchema } } =>
Type.of[MacroHelpers.AsTuple[FrameSchema]] match
case '[`structSchema`] =>
'{ ${ df }.untyped.collect.toList.map(row => ${ enc }.decode(row).asInstanceOf[A]) }
Expand All @@ -58,7 +58,7 @@ object DataFrame:
val structColumns = allColumns(Type.of[structSchema])
val errorMsg = s"A data frame with columns:\n${showColumns(frameColumns)}\ncannot be collected as a list of ${Type.show[A]}, which would be encoded as a row with columns:\n${showColumns(structColumns)}"
quotes.reflect.report.errorAndAbort(errorMsg)
case '{ $enc: DataType.Encoder[A] { type ColumnType = colType } } =>
case '{ $enc: Encoder[A] { type ColumnType = colType } } =>
def fromDataType[T : Type] =
Type.of[T] match
case '[`colType`] =>
Expand Down
7 changes: 3 additions & 4 deletions src/main/DataFrameBuilders.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@ import scala.quoted._
import org.apache.spark.sql
import org.apache.spark.sql.SparkSession
import org.virtuslab.iskra.DataFrame
import org.virtuslab.iskra.types.{DataType, StructType}
import DataType.{Encoder, StructEncoder, PrimitiveEncoder}
import org.virtuslab.iskra.types.{DataType, StructType, Encoder, StructEncoder, PrimitiveEncoder}

object DataFrameBuilders:
extension [A](seq: Seq[A])(using encoder: Encoder[A])
transparent inline def toTypedDF(using spark: SparkSession): DataFrame[?] = ${ toTypedDFImpl('seq, 'encoder, 'spark) }

private def toTypedDFImpl[A : Type](seq: Expr[Seq[A]], encoder: Expr[Encoder[A]], spark: Expr[SparkSession])(using Quotes) =
val (schemaType, schema, encodeFun) = encoder match
case '{ $e: DataType.StructEncoder.Aux[A, t] } =>
case '{ $e: StructEncoder.Aux[A, t] } =>
val schema = '{ ${ e }.catalystType }
val encodeFun: Expr[A => sql.Row] = '{ ${ e }.encode }
(Type.of[t], schema, encodeFun)
case '{ $e: DataType.Encoder.Aux[tpe, t] } =>
case '{ $e: Encoder.Aux[tpe, t] } =>
val schema = '{
sql.types.StructType(Seq(
sql.types.StructField("value", ${ encoder }.catalystType, ${ encoder }.isNullable )
Expand Down
8 changes: 4 additions & 4 deletions src/main/UntypedOps.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package org.virtuslab.iskra

import scala.quoted.*
import types.{DataType, StructType}
import types.{DataType, Encoder, StructType, StructEncoder}

object UntypedOps:
extension (untyped: UntypedColumn)
def typed[A <: DataType] = Column[A](untyped)

extension (df: UntypedDataFrame)
transparent inline def typed[A](using encoder: DataType.StructEncoder[A]): DataFrame[?] = ${ typedDataFrameImpl('df, 'encoder) } // TODO: Check schema at runtime? Check if names of columns match?
transparent inline def typed[A](using encoder: StructEncoder[A]): DataFrame[?] = ${ typedDataFrameImpl('df, 'encoder) } // TODO: Check schema at runtime? Check if names of columns match?

private def typedDataFrameImpl[A : Type](df: Expr[UntypedDataFrame], encoder: Expr[DataType.StructEncoder[A]])(using Quotes) =
private def typedDataFrameImpl[A : Type](df: Expr[UntypedDataFrame], encoder: Expr[StructEncoder[A]])(using Quotes) =
encoder match
case '{ ${e}: DataType.Encoder.Aux[tpe, StructType[t]] } =>
case '{ ${e}: Encoder.Aux[tpe, StructType[t]] } =>
'{ DataFrame[t](${ df }) }
14 changes: 14 additions & 0 deletions src/main/When.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.virtuslab.iskra

import org.apache.spark.sql.{functions => f, Column => UntypedColumn}
import org.virtuslab.iskra.types.{Coerce, DataType, BooleanOptType}

/** Typed wrappers around Spark's `when` / `otherwise` conditional column builders. */
object When:
/** A conditional column under construction whose branch values so far have type `T`.
  * It is itself usable as a `Column[DataType.Nullable[T]]`.
  * NOTE(review): presumably nullable because Spark yields null for rows that match no
  * branch when `otherwise` is never called - confirm against the Spark documentation.
  */
class WhenColumn[T <: DataType](untyped: UntypedColumn) extends Column[DataType.Nullable[T]](untyped):
// Adds another branch by delegating to the untyped column's `when`.
// The value type of the resulting chain is `coerce.Coerced`, taken from the given `Coerce[T, U]`.
def when[U <: DataType](condition: Column[BooleanOptType], value: Column[U])(using coerce: Coerce[T, U]): WhenColumn[coerce.Coerced] =
WhenColumn(this.untyped.when(condition.untyped, value.untyped))
// Closes the chain with a fallback value by delegating to the untyped column's `otherwise`.
// The result is a plain `Column` typed as `coerce.Coerced` (no longer wrapped in `Nullable`).
def otherwise[U <: DataType](value: Column[U])(using coerce: Coerce[T, U]): Column[coerce.Coerced] =
Column(this.untyped.otherwise(value.untyped))

// Entry point: starts a conditional chain with a first branch using Spark's `functions.when`.
def when[T <: DataType](condition: Column[BooleanOptType], value: Column[T]): WhenColumn[T] =
WhenColumn(f.when(condition.untyped, value.untyped))
2 changes: 1 addition & 1 deletion src/main/api/api.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export org.virtuslab.iskra.$
export org.virtuslab.iskra.{Column, DataFrame, UntypedColumn, UntypedDataFrame, :=, /}

object functions:
export org.virtuslab.iskra.functions.lit
export org.virtuslab.iskra.functions.{lit, when}
export org.virtuslab.iskra.functions.Aggregates.*

export org.apache.spark.sql.SparkSession
2 changes: 1 addition & 1 deletion src/main/functions/lit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ package org.virtuslab.iskra.functions

import org.apache.spark.sql
import org.virtuslab.iskra.Column
import org.virtuslab.iskra.types.DataType.PrimitiveEncoder
import org.virtuslab.iskra.types.PrimitiveEncoder

def lit[A](value: A)(using encoder: PrimitiveEncoder[A]): Column[encoder.ColumnType] = Column(sql.functions.lit(encoder.encode(value)))
4 changes: 4 additions & 0 deletions src/main/functions/when.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.virtuslab.iskra
package functions

export When.when
17 changes: 17 additions & 0 deletions src/main/types/Coerce.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.virtuslab.iskra
package types

import DataType.{CommonNumericNonNullableType, CommonNumericNullableType, NumericOptType, NumericType}

/** Type-level evidence that column types `A` and `B` can be brought to a common column type.
  * The common type is exposed as the type member `Coerced`.
  */
trait Coerce[-A <: DataType, -B <: DataType]:
type Coerced <: DataType

/** Given instances of `Coerce`. */
object Coerce:
// Identical types coerce to themselves.
given sameType[A <: DataType]: Coerce[A, A] with
override type Coerced = A

// Two `NumericOptType`s coerce to `CommonNumericNullableType[A, B]`.
given nullable[A <: NumericOptType, B <: NumericOptType]: Coerce[A, B] with
override type Coerced = CommonNumericNullableType[A, B]

// Two `NumericType`s coerce to `CommonNumericNonNullableType[A, B]`.
given nonNullable[A <: NumericType, B <: NumericType]: Coerce[A, B] with
override type Coerced = CommonNumericNonNullableType[A, B]
180 changes: 1 addition & 179 deletions src/main/types/DataType.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
package org.virtuslab.iskra
package types

import scala.quoted._
import scala.deriving.Mirror
import org.apache.spark.sql
import MacroHelpers.TupleSubtype

sealed trait DataType

object DataType:
Expand Down Expand Up @@ -38,7 +33,7 @@ object DataType:
case DoubleOptType => DoubleOptType
case StructOptType[schema] => StructOptType[schema]

type CommonNumericOptType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match
type CommonNumericNullableType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match
case (DoubleOptType, _) | (_, DoubleOptType) => DoubleOptType
case (FloatOptType, _) | (_, FloatOptType) => FloatOptType
case (LongOptType, _) | (_, LongOptType) => LongOptType
Expand All @@ -54,179 +49,6 @@ object DataType:
case (ShortOptType, _) | (_, ShortOptType) => ShortType
case (ByteOptType, _) | (_, ByteOptType) => ByteType

/** Describes how Scala values of type `A` are represented in a data frame column. */
trait Encoder[-A]:
// The typed column representation (a `DataType`) corresponding to `A`.
type ColumnType <: DataType
// Converts a Scala value into the raw value placed in a row.
def encode(value: A): Any
// Converts a raw value read from a row back into a Scala value.
def decode(value: Any): Any
// The Spark SQL (Catalyst) data type of the column.
def catalystType: sql.types.DataType
// Passed as the `nullable` argument of `sql.types.StructField` when schemas are built in this file.
def isNullable: Boolean

/** Subtype of `Encoder` that adds no members; the primitive encoder traits below extend it. */
trait PrimitiveEncoder[-A] extends Encoder[A]

/** Base for encoders of optional primitive values: it is an encoder for `Option[A]`. */
trait PrimitiveNullableEncoder[-A] extends PrimitiveEncoder[Option[A]]:
// `None` is stored as null, `Some(x)` as `x`.
def encode(value: Option[A]) = value.orNull
// A null raw value becomes `None`, anything else is wrapped in `Some`.
def decode(value: Any) = Option(value)
def isNullable = true

/** Base for encoders of primitive values that are stored in a row as-is and are never null. */
trait PrimitiveNonNullableEncoder[-A] extends PrimitiveEncoder[A]:
  /** Stores the value unchanged. */
  def encode(value: A) = value
  /** Returns the raw row value unchanged. */
  def decode(value: Any) = value
  // A non-nullable encoder must not report itself as nullable: this flag is passed as the
  // `nullable` argument of `sql.types.StructField`, and the derived non-nullable struct
  // encoder in this file reports `false` as well.
  def isNullable = false


/** Given `Encoder` instances for primitive Scala types.
  * Each primitive type gets a pair of givens that share the same Catalyst type:
  * a non-nullable encoder for the plain type and a nullable one for its `Option`.
  */
object Encoder:
// An `Encoder[A]` whose `ColumnType` is fixed to `E`.
type Aux[-A, E <: DataType] = Encoder[A] { type ColumnType = E }

// Boolean
inline given boolean: PrimitiveNonNullableEncoder[Boolean] with
type ColumnType = BooleanType
def catalystType = sql.types.BooleanType
inline given booleanOpt: PrimitiveNullableEncoder[Boolean] with
type ColumnType = BooleanOptType
def catalystType = sql.types.BooleanType

// String
inline given string: PrimitiveNonNullableEncoder[String] with
type ColumnType = StringType
def catalystType = sql.types.StringType
inline given stringOpt: PrimitiveNullableEncoder[String] with
type ColumnType = StringOptType
def catalystType = sql.types.StringType

// Byte
inline given byte: PrimitiveNonNullableEncoder[Byte] with
type ColumnType = ByteType
def catalystType = sql.types.ByteType
inline given byteOpt: PrimitiveNullableEncoder[Byte] with
type ColumnType = ByteOptType
def catalystType = sql.types.ByteType

// Short
inline given short: PrimitiveNonNullableEncoder[Short] with
type ColumnType = ShortType
def catalystType = sql.types.ShortType
inline given shortOpt: PrimitiveNullableEncoder[Short] with
type ColumnType = ShortOptType
def catalystType = sql.types.ShortType

// Int
inline given int: PrimitiveNonNullableEncoder[Int] with
type ColumnType = IntegerType
def catalystType = sql.types.IntegerType
inline given intOpt: PrimitiveNullableEncoder[Int] with
type ColumnType = IntegerOptType
def catalystType = sql.types.IntegerType

// Long
inline given long: PrimitiveNonNullableEncoder[Long] with
type ColumnType = LongType
def catalystType = sql.types.LongType
inline given longOpt: PrimitiveNullableEncoder[Long] with
type ColumnType = LongOptType
def catalystType = sql.types.LongType

// Float
inline given float: PrimitiveNonNullableEncoder[Float] with
type ColumnType = FloatType
def catalystType = sql.types.FloatType
inline given floatOpt: PrimitiveNullableEncoder[Float] with
type ColumnType = FloatOptType
def catalystType = sql.types.FloatType

// Double
inline given double: PrimitiveNonNullableEncoder[Double] with
type ColumnType = DoubleType
def catalystType = sql.types.DoubleType
inline given doubleOpt: PrimitiveNullableEncoder[Double] with
type ColumnType = DoubleOptType
def catalystType = sql.types.DoubleType

// Re-exports the struct encoder givens from `StructEncoder`.
// NOTE(review): presumably so they are found in the implicit scope of `Encoder` - confirm.
export StructEncoder.{fromMirror, optFromMirror}

/** An `Encoder` for values represented as a whole row (`sql.Row`) with a struct schema. */
trait StructEncoder[-A] extends Encoder[A]:
// Tuple type describing the columns of the struct.
type StructSchema <: Tuple
// The column type is always the struct type built from `StructSchema`.
type ColumnType = StructType[StructSchema]
// Narrowed from `Encoder`: the Catalyst type is a struct type.
override def catalystType: sql.types.StructType
// Narrowed from `Encoder`: a value is encoded as a row.
override def encode(a: A): sql.Row

object StructEncoder:
type Aux[-A, E <: Tuple] = StructEncoder[A] { type StructSchema = E }

// Macro-time description of one field of the type an encoder is derived for.
private case class ColumnInfo(
// Type of the field label (a constant type; see `Type.valueOfConstant` in getSubcolumnInfos).
labelType: Type[?],
// The label as a string, used as the column name.
labelValue: String,
// Scala type of the field.
decodedType: Type[?],
// Expression of the encoder summoned for `decodedType`.
encoder: Expr[Encoder[?]]
)

// Builds the tuple type describing the struct schema from the field descriptions.
// An empty list gives `EmptyTuple`; otherwise the head contributes `(label := columnType)`,
// where `columnType` is the `ColumnType` of the field's encoder, prepended to the schema
// computed recursively for the tail.
private def getColumnSchemaType(using quotes: Quotes)(subcolumnInfos: List[ColumnInfo]): Type[?] =
subcolumnInfos match
case Nil => Type.of[EmptyTuple]
case info :: tail =>
info.labelType match
case '[Name.Subtype[lt]] =>
info.encoder match
case '{ ${encoder}: Encoder.Aux[tpe, DataType.Subtype[e]] } =>
getColumnSchemaType(tail) match
case '[TupleSubtype[tailType]] =>
Type.of[(lt := e) *: tailType]

// Walks the tuples of element labels and element types in lockstep and returns one
// `ColumnInfo` per element. An `Encoder` is summoned for each element type; if none is
// found, compilation is aborted with an error naming the type.
private def getSubcolumnInfos(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[ColumnInfo] =
// NOTE(review): `Select` does not appear to be used inside this method.
import quotes.reflect.Select
elemLabels match
case '[EmptyTuple] => Nil
case '[label *: labels] =>
// The label is a constant type; its value becomes the column name.
val labelValue = Type.valueOfConstant[label].get.toString
elemTypes match
case '[tpe *: types] =>
Expr.summon[Encoder[tpe]] match
case Some(encoderExpr) =>
ColumnInfo(Type.of[label], labelValue, Type.of[tpe], encoderExpr) :: getSubcolumnInfos(Type.of[labels], Type.of[types])
case _ => quotes.reflect.report.errorAndAbort(s"Could not summon encoder for ${Type.show[tpe]}")

// Given struct encoder for `A`, generated at compile time by the macro `fromMirrorImpl`.
transparent inline given fromMirror[A]: StructEncoder[A] = ${ fromMirrorImpl[A] }

/** Macro implementation deriving a `StructEncoder[A]` from the `Mirror` of `A`.
  *
  * For every element of the product mirror an `Encoder` is summoned. The generated encoder
  * builds its Catalyst struct type from the element labels and the element encoders,
  * encodes a value as a `sql.Row` of the encoded fields, and decodes a row by decoding
  * each position and passing the resulting tuple to the mirror's `fromProduct`.
  *
  * Compilation is aborted with an error when no `Mirror` can be summoned for `A` or when
  * the summoned mirror is not a product mirror.
  */
def fromMirrorImpl[A : Type](using q: Quotes): Expr[StructEncoder[A]] =
  // Report a proper compile error (as done elsewhere in this file) instead of throwing
  // a raw exception from inside the macro.
  Expr.summon[Mirror.Of[A]].getOrElse(quotes.reflect.report.errorAndAbort(s"Could not find Mirror when generating encoder for ${Type.show[A]}")) match
    case '{ ${mirror}: Mirror.ProductOf[A] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes } } =>
      val subcolumnInfos = getSubcolumnInfos(Type.of[elementLabels], Type.of[elementTypes])
      val columnSchemaType = getColumnSchemaType(subcolumnInfos)

      // One StructField per element: name from the label, type and nullability from its encoder.
      val structFieldExprs = subcolumnInfos.map { info =>
        '{ sql.types.StructField(${Expr(info.labelValue)}, ${info.encoder}.catalystType, ${info.encoder}.isNullable) }
      }
      val structFields = Expr.ofSeq(structFieldExprs)

      // Expressions encoding each field of `value`, selected by its label.
      def rowElements(value: Expr[A]) =
        subcolumnInfos.map { info =>
          import quotes.reflect.*
          info.decodedType match
            case '[t] =>
              '{ ${info.encoder}.asInstanceOf[Encoder[t]].encode(${ Select.unique(value.asTerm, info.labelValue).asExprOf[t] }) }
        }

      // Expression of a tuple holding the decoded value of every position of `row`.
      def rowElementsTuple(row: Expr[sql.Row]): Expr[Tuple] =
        val elements = subcolumnInfos.zipWithIndex.map { (info, idx) =>
          given Quotes = q
          '{ ${info.encoder}.decode(${row}.get(${Expr(idx)})) }
        }
        Expr.ofTupleFromSeq(elements)

      columnSchemaType match
        case '[TupleSubtype[t]] =>
          '{
            (new StructEncoder[A] {
              override type StructSchema = t
              override def catalystType = sql.types.StructType(${ structFields })
              override def isNullable = false
              override def encode(a: A) =
                sql.Row.fromSeq(${ Expr.ofSeq(rowElements('a)) })
              override def decode(a: Any) =
                ${mirror}.fromProduct(${ rowElementsTuple('{a.asInstanceOf[sql.Row]}) })
            }): StructEncoder[A] { type StructSchema = t }
          }
    case _ =>
      // Without this fallback a non-product mirror would end in a MatchError during expansion.
      quotes.reflect.report.errorAndAbort(s"Could not generate encoder for ${Type.show[A]}: a product Mirror is required")
end fromMirrorImpl

/** Encoder for an optional struct, built on top of the struct encoder for `A`.
  *
  * `None` is encoded as null and `Some(a)` as the row produced by `encoder.encode(a)`.
  * Decoding maps a null raw value to `None` and any other value to `Some` of the value
  * decoded by `encoder`. The Catalyst type is that of the underlying struct encoder and
  * the column is reported as nullable.
  */
inline given optFromMirror[A](using encoder: StructEncoder[A]): (Encoder[Option[A]] { type ColumnType = StructOptType[encoder.StructSchema] }) =
  new Encoder[Option[A]]:
    override type ColumnType = StructOptType[encoder.StructSchema]
    override def encode(value: Option[A]): Any = value.map(encoder.encode).orNull
    // Decode the value that was actually read. Wrapping `encoder.decode` itself would
    // ignore `value` and always produce `Some(<function>)`.
    override def decode(value: Any): Any = Option(value).map(encoder.decode)
    override def catalystType = encoder.catalystType
    override def isNullable = true

import DataType.NotNull

sealed class BooleanOptType extends DataType
Expand Down
Loading

0 comments on commit a6ff6a9

Please sign in to comment.