Skip to content

Commit

Permalink
Merge pull request #36 from VirtusLab/new_named_columns
Browse files Browse the repository at this point in the history
Redesign handling of named columns

* Separate general Column supertype from data type specific Col[T]
* Remove column names from members of view-like refinements
* Rely on tuples instead of varargs in user-facing APIs of methods like select, agg, groupBy
* Assign names to columns via implicit conversions
  • Loading branch information
prolativ authored Jul 8, 2024
2 parents dd486a8 + fd2b51e commit acb3165
Show file tree
Hide file tree
Showing 19 changed files with 314 additions and 205 deletions.
35 changes: 35 additions & 0 deletions src/main/CollectColumns.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.virtuslab.iskra

import scala.compiletime.error

import org.virtuslab.iskra.types.DataType

// TODO should it be covariant or not?
/** Type class that turns a value of type `C` (a single named column or a tuple
  * of column-like values) into a flat sequence of untyped Spark columns,
  * while tracking the resulting schema at the type level.
  *
  * `CollectedColumns` is the tuple of `name := dataType` entries contributed
  * by `C`; macros pattern-match on this refinement (see note below), so
  * instances should expose it in their declared type.
  *
  * NOTE(review): declared contravariant in `C` — the TODO above suggests the
  * variance is still undecided; confirm before relying on it.
  */
trait CollectColumns[-C]:
  type CollectedColumns <: Tuple
  def underlyingColumns(c: C): Seq[UntypedColumn]

// Using `given ... with { ... }` syntax might sometimes break pattern match on `CollectColumns[...] { type CollectedColumns = cc }`

object CollectColumns:
  /** A single named column collects as a one-element schema `(N := T)`. */
  given collectNamedColumn[N <: Name, T <: DataType]: CollectColumns[NamedColumn[N, T]] with
    type CollectedColumns = (N := T) *: EmptyTuple
    def underlyingColumns(c: NamedColumn[N, T]) = Seq(c.untyped)

  /** An already-assembled column group contributes its schema `S` unchanged. */
  given collectColumnsWithSchema[S <: Tuple]: CollectColumns[ColumnsWithSchema[S]] with
    type CollectedColumns = S
    def underlyingColumns(c: ColumnsWithSchema[S]) = c.underlyingColumns

  /** Base case of the tuple recursion: an empty tuple yields no columns.
    * (The original unused type parameter `[S]` has been dropped — it was
    * never referenced and only widened the given's signature.)
    */
  given collectEmptyTuple: CollectColumns[EmptyTuple] with
    type CollectedColumns = EmptyTuple
    def underlyingColumns(c: EmptyTuple) = Seq.empty

  /** Inductive case: collect the head, then the tail, concatenating both the
    * runtime columns and the type-level schemas. Written with an explicit
    * refinement in the declared type so `CollectedColumns` stays visible to
    * macro pattern matches of the form
    * `CollectColumns[...] { type CollectedColumns = cc }`.
    */
  given collectCons[H, T <: Tuple](using collectHead: CollectColumns[H], collectTail: CollectColumns[T]): (CollectColumns[H *: T] { type CollectedColumns = Tuple.Concat[collectHead.CollectedColumns, collectTail.CollectedColumns] }) =
    new CollectColumns[H *: T]:
      type CollectedColumns = Tuple.Concat[collectHead.CollectedColumns, collectTail.CollectedColumns]
      def underlyingColumns(c: H *: T) = collectHead.underlyingColumns(c.head) ++ collectTail.underlyingColumns(c.tail)

  // TODO Customize error message for different operations with an explanation
  /** Raised from macro expansion when no `CollectColumns` instance can be
    * summoned for the user-supplied column expression. Defined inside the
    * companion so call sites can refer to it as
    * `CollectColumns.CannotCollectColumns` (as the macros in this project do).
    */
  class CannotCollectColumns(typeName: String)
    extends Exception(s"Could not find an instance of CollectColumns for ${typeName}")
119 changes: 70 additions & 49 deletions src/main/Column.scala
Original file line number Diff line number Diff line change
@@ -1,73 +1,94 @@
package org.virtuslab.iskra

import scala.language.implicitConversions

import scala.quoted.*

import org.apache.spark.sql.{Column => UntypedColumn}
import types.DataType
import MacroHelpers.TupleSubtype

class Column(val untyped: UntypedColumn):
inline def name(using v: ValueOf[Name]): Name = v.value

object Column:
implicit transparent inline def columnToNamedColumn(inline col: Col[?]): NamedColumn[?, ?] =
${ columnToNamedColumnImpl('col) }

private def columnToNamedColumnImpl(col: Expr[Col[?]])(using Quotes): Expr[NamedColumn[?, ?]] =
import quotes.reflect.*
col match
case '{ ($v: StructuralSchemaView).selectDynamic($nm: Name).$asInstanceOf$[Col[tp]] } =>
nm.asTerm.tpe.asType match
case '[Name.Subtype[n]] =>
'{ NamedColumn[n, tp](${ col }.untyped.as(${ nm })) }
case '{ $c: Col[tp] } =>
col.asTerm match
case Inlined(_, _, Ident(name)) =>
ConstantType(StringConstant(name)).asType match
case '[Name.Subtype[n]] =>
val alias = Literal(StringConstant(name)).asExprOf[Name]
'{ NamedColumn[n, tp](${ col }.untyped.as(${ alias })) }

extension [T <: DataType](col: Col[T])
inline def as[N <: Name](name: N): NamedColumn[N, T] =
NamedColumn[N, T](col.untyped.as(name))
inline def alias[N <: Name](name: N): NamedColumn[N, T] =
NamedColumn[N, T](col.untyped.as(name))

extension [T1 <: DataType](col1: Col[T1])
inline def +[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Plus[T1, T2]): Col[op.Out] = op(col1, col2)
inline def -[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Minus[T1, T2]): Col[op.Out] = op(col1, col2)
inline def *[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Mult[T1, T2]): Col[op.Out] = op(col1, col2)
inline def /[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Div[T1, T2]): Col[op.Out] = op(col1, col2)
inline def ++[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.PlusPlus[T1, T2]): Col[op.Out] = op(col1, col2)
inline def <[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Lt[T1, T2]): Col[op.Out] = op(col1, col2)
inline def <=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Le[T1, T2]): Col[op.Out] = op(col1, col2)
inline def >[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Gt[T1, T2]): Col[op.Out] = op(col1, col2)
inline def >=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Ge[T1, T2]): Col[op.Out] = op(col1, col2)
inline def ===[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Eq[T1, T2]): Col[op.Out] = op(col1, col2)
inline def =!=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Ne[T1, T2]): Col[op.Out] = op(col1, col2)
inline def &&[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.And[T1, T2]): Col[op.Out] = op(col1, col2)
inline def ||[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Or[T1, T2]): Col[op.Out] = op(col1, col2)


class Col[+T <: DataType](untyped: UntypedColumn) extends Column(untyped)

sealed trait NamedColumns[Schema](val underlyingColumns: Seq[UntypedColumn])

object Columns:
transparent inline def apply(inline columns: NamedColumns[?]*): NamedColumns[?] = ${ applyImpl('columns) }
transparent inline def apply[C <: NamedColumns](columns: C): ColumnsWithSchema[?] = ${ applyImpl('columns) }

private def applyImpl(columns: Expr[Seq[NamedColumns[?]]])(using Quotes): Expr[NamedColumns[?]] =
private def applyImpl[C : Type](columns: Expr[C])(using Quotes): Expr[ColumnsWithSchema[?]] =
import quotes.reflect.*

val columnValuesWithTypes = columns match
case Varargs(colExprs) =>
colExprs.map { arg =>
arg match
case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema])
}
Expr.summon[CollectColumns[C]] match
case Some(collectColumns) =>
collectColumns match
case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } =>
Type.of[collectedColumns] match
case '[TupleSubtype[collectedCols]] =>
'{
val cols = ${ cc }.underlyingColumns(${ columns })
ColumnsWithSchema[collectedCols](cols)
}
case None =>
throw CollectColumns.CannotCollectColumns(Type.show[C])

val columnsValues = columnValuesWithTypes.map(_._1)
val columnsTypes = columnValuesWithTypes.map(_._2)

val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes)
trait NamedColumnOrColumnsLike

schemaTpe match
case '[s] =>
'{
val cols = ${ Expr.ofSeq(columnsValues) }.flatten
new NamedColumns[s](cols) {}
}
type NamedColumns = Repeated[NamedColumnOrColumnsLike]

class Column[+T <: DataType](val untyped: UntypedColumn):
class NamedColumn[N <: Name, T <: DataType](val untyped: UntypedColumn)
extends NamedColumnOrColumnsLike

inline def name(using v: ValueOf[Name]): Name = v.value
class ColumnsWithSchema[Schema <: Tuple](val underlyingColumns: Seq[UntypedColumn]) extends NamedColumnOrColumnsLike

object Column:
extension [T <: DataType](col: Column[T])
inline def as[N <: Name](name: N)(using v: ValueOf[N]): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(v.value))
inline def alias[N <: Name](name: N)(using v: ValueOf[N]): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(v.value))

extension [T1 <: DataType](col1: Column[T1])
inline def +[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Plus[T1, T2]): Column[op.Out] = op(col1, col2)
inline def -[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Minus[T1, T2]): Column[op.Out] = op(col1, col2)
inline def *[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Mult[T1, T2]): Column[op.Out] = op(col1, col2)
inline def /[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Div[T1, T2]): Column[op.Out] = op(col1, col2)
inline def ++[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.PlusPlus[T1, T2]): Column[op.Out] = op(col1, col2)
inline def <[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Lt[T1, T2]): Column[op.Out] = op(col1, col2)
inline def <=[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Le[T1, T2]): Column[op.Out] = op(col1, col2)
inline def >[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Gt[T1, T2]): Column[op.Out] = op(col1, col2)
inline def >=[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Ge[T1, T2]): Column[op.Out] = op(col1, col2)
inline def ===[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Eq[T1, T2]): Column[op.Out] = op(col1, col2)
inline def =!=[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Ne[T1, T2]): Column[op.Out] = op(col1, col2)
inline def &&[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.And[T1, T2]): Column[op.Out] = op(col1, col2)
inline def ||[T2 <: DataType](col2: Column[T2])(using op: ColumnOp.Or[T1, T2]): Column[op.Out] = op(col1, col2)

@annotation.showAsInfix
class :=[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn)
extends Column[T](untyped)
with NamedColumns[(L := T) *: EmptyTuple](Seq(untyped))
trait :=[L <: ColumnLabel, T <: DataType]

@annotation.showAsInfix
trait /[+Prefix <: Name, +Suffix <: Name]

type LabeledColumn[L <: LabeledColumn.Label, T <: DataType] = :=[L, T]

object LabeledColumn:
type Label = Name | (Name / Name)
def apply[L <: LabeledColumn.Label, T <: DataType](untyped: UntypedColumn) = new :=[L, T](untyped)
type ColumnLabel = Name | (Name / Name)
2 changes: 1 addition & 1 deletion src/main/ColumnOp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.virtuslab.iskra
import scala.quoted.*
import org.apache.spark.sql
import org.apache.spark.sql.functions.concat
import org.virtuslab.iskra.{Column as Col}
import org.virtuslab.iskra.Col
import org.virtuslab.iskra.UntypedOps.typed
import org.virtuslab.iskra.types.*
import DataType.*
Expand Down
6 changes: 3 additions & 3 deletions src/main/FrameSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ object FrameSchema:
case TupleSubtype[s2] => S1 *: s2
case _ => S1 *: S2 *: EmptyTuple

type NullableLabeledColumn[T] = T match
type NullableLabeledDataType[T] = T match
case label := tpe => label := DataType.Nullable[tpe]

type NullableSchema[T] = T match
case TupleSubtype[s] => Tuple.Map[s, NullableLabeledColumn]
case _ => NullableLabeledColumn[T]
case TupleSubtype[s] => Tuple.Map[s, NullableLabeledDataType]
case _ => NullableLabeledDataType[T]

def reownType[Owner <: Name : Type](schema: Type[?])(using Quotes): Type[?] =
schema match
Expand Down
96 changes: 40 additions & 56 deletions src/main/Grouping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,45 +13,38 @@ object GroupBy:

given groupByOps: {} with
extension [View <: SchemaView](groupBy: GroupBy[View])
transparent inline def apply(inline groupingColumns: View ?=> NamedColumns[?]*) = ${ applyImpl[View]('groupBy, 'groupingColumns) }
transparent inline def apply[C <: NamedColumns](groupingColumns: View ?=> C) = ${ applyImpl[View, C]('groupBy, 'groupingColumns) }

def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] =
private def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] =
import quotes.reflect.asTerm
val viewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[S]]
viewExpr.asTerm.tpe.asType match
case '[SchemaView.Subtype[v]] =>
'{ GroupBy[v](${ viewExpr }.asInstanceOf[v], ${ df }.untyped) }

def applyImpl[View <: SchemaView : Type](groupBy: Expr[GroupBy[View]], groupingColumns: Expr[Seq[View ?=> NamedColumns[?]]])(using Quotes): Expr[GroupedDataFrame[View]] =
private def applyImpl[View <: SchemaView : Type, C : Type](groupBy: Expr[GroupBy[View]], groupingColumns: Expr[View ?=> C])(using Quotes): Expr[GroupedDataFrame[View]] =
import quotes.reflect.*

val columnValuesWithTypes = groupingColumns match
case Varargs(colExprs) =>
colExprs.map { arg =>
val reduced = Term.betaReduce('{$arg(using ${ groupBy }.view)}.asTerm).get
reduced.asExpr match
case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema])
}

val columnsValues = columnValuesWithTypes.map(_._1)
val columnsTypes = columnValuesWithTypes.map(_._2)

val groupedSchemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes)
groupedSchemaTpe match
case '[TupleSubtype[groupingKeys]] =>
val groupedViewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[groupingKeys]]

groupedViewExpr.asTerm.tpe.asType match
case '[SchemaView.Subtype[groupedView]] =>
'{
val groupingCols = ${ Expr.ofSeq(columnsValues) }.flatten
new GroupedDataFrame[View]:
type GroupingKeys = groupingKeys
type GroupedView = groupedView
def underlying = ${ groupBy }.underlying.groupBy(groupingCols*)
def fullView = ${ groupBy }.view
def groupedView = ${ groupedViewExpr }.asInstanceOf[GroupedView]
}
Expr.summon[CollectColumns[C]] match
case Some(collectColumns) =>
collectColumns match
case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } =>
Type.of[collectedColumns] match
case '[TupleSubtype[collectedCols]] =>
val groupedViewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[collectedCols]]
groupedViewExpr.asTerm.tpe.asType match
case '[SchemaView.Subtype[groupedView]] =>
'{
val groupingCols = ${ cc }.underlyingColumns(${ groupingColumns }(using ${ groupBy }.view))
new GroupedDataFrame[View]:
type GroupingKeys = collectedCols
type GroupedView = groupedView
def underlying = ${ groupBy }.underlying.groupBy(groupingCols*)
def fullView = ${ groupBy }.view
def groupedView = ${ groupedViewExpr }.asInstanceOf[GroupedView]
}
case None =>
throw CollectColumns.CannotCollectColumns(Type.show[C])

// TODO: Rename to RelationalGroupedDataset and handle other aggregations: cube, rollup (and pivot?)
trait GroupedDataFrame[FullView <: SchemaView]:
Expand All @@ -66,13 +59,12 @@ trait GroupedDataFrame[FullView <: SchemaView]:
object GroupedDataFrame:
given groupedDataFrameOps: {} with
extension [FullView <: SchemaView, GroupKeys <: Tuple, GroupView <: SchemaView](gdf: GroupedDataFrame[FullView]{ type GroupedView = GroupView; type GroupingKeys = GroupKeys })
transparent inline def agg(inline columns: (Agg { type View = FullView }, GroupView) ?=> NamedColumns[?]*): StructDataFrame[?] =
${ aggImpl[FullView, GroupKeys, GroupView]('gdf, 'columns) }
transparent inline def agg[C <: NamedColumns](columns: (Agg { type View = FullView }, GroupView) ?=> C): StructDataFrame[?] =
${ aggImpl[FullView, GroupKeys, GroupView, C]('gdf, 'columns) }


def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type](
private def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type, C : Type](
gdf: Expr[GroupedDataFrame[FullView] { type GroupedView = GroupView }],
columns: Expr[Seq[(Agg { type View = FullView }, GroupView) ?=> NamedColumns[?]]]
columns: Expr[(Agg { type View = FullView }, GroupView) ?=> C]
)(using Quotes): Expr[StructDataFrame[?]] =
import quotes.reflect.*

Expand All @@ -82,27 +74,19 @@ object GroupedDataFrame:
val view = ${ gdf }.fullView
}

val columnValuesWithTypes = columns match
case Varargs(colExprs) =>
colExprs.map { arg =>
val reduced = Term.betaReduce('{$arg(using ${ aggWrapper }, ${ gdf }.groupedView)}.asTerm).get
reduced.asExpr match
case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema])
}

val columnsValues = columnValuesWithTypes.map(_._1)
val columnsTypes = columnValuesWithTypes.map(_._2)

val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes)
schemaTpe match
case '[s] =>
'{
// TODO assert cols is not empty
val cols = ${ Expr.ofSeq(columnsValues) }.flatten
StructDataFrame[FrameSchema.Merge[GroupingKeys, s]](
${ gdf }.underlying.agg(cols.head, cols.tail*)
)
}
Expr.summon[CollectColumns[C]] match
case Some(collectColumns) =>
collectColumns match
case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } =>
'{
// TODO assert cols is not empty
val cols = ${ cc }.underlyingColumns(${ columns }(using ${ aggWrapper }, ${ gdf }.groupedView))
StructDataFrame[FrameSchema.Merge[GroupingKeys, collectedColumns]](
${ gdf }.underlying.agg(cols.head, cols.tail*)
)
}
case None =>
throw CollectColumns.CannotCollectColumns(Type.show[C])

trait Agg:
type View <: SchemaView
Expand Down
2 changes: 1 addition & 1 deletion src/main/JoinOnCondition.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ object JoinOnCondition:
import quotes.reflect.*

'{ ${ condition }(using ${ joiningView }) } match
case '{ $cond: Column[BooleanOptType] } =>
case '{ $cond: Col[BooleanOptType] } =>
'{
val joined = ${ join }.left.join(${ join }.right, ${ cond }.untyped, JoinType.typeName[T])
StructDataFrame[JoinedSchema](joined)
Expand Down
25 changes: 25 additions & 0 deletions src/main/Repeated.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.virtuslab.iskra

/** Emulates repeated ("varargs") parameters using plain tuples: a value of
  * `Repeated[A]` is either a single `A` or a tuple of 2 to 22 `A`s.
  * User-facing APIs (e.g. `select`, `agg`, `groupBy`) accept this instead of
  * true varargs so that the type of each element is preserved individually.
  * Capped at arity 22 — the largest tuple literal spelled out here.
  */
type Repeated[A] =
  A
  | (A, A)
  | (A, A, A)
  | (A, A, A, A)
  | (A, A, A, A, A)
  | (A, A, A, A, A, A)
  | (A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A)
  | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) // 22 is maximal arity
Loading

0 comments on commit acb3165

Please sign in to comment.