Skip to content

Commit

Permalink
Clean up named columns abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
prolativ committed Jul 8, 2024
1 parent f0753a4 commit fd2b51e
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 68 deletions.
28 changes: 14 additions & 14 deletions src/main/CollectColumns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,34 @@ package org.virtuslab.iskra

import scala.compiletime.error

import org.virtuslab.iskra.types.DataType

// TODO should it be covariant or not?
trait CollectColumns[-C]:
type CollectedColumns <: Tuple
def underlyingColumns(c: C): Seq[UntypedColumn]

// Using `given ... with { ... }` syntax might sometimes break pattern match on `CollectColumns[...] { type CollectedColumns = cc }`

object CollectColumns extends CollectColumnsLowPrio:
given collectSingle[S <: Tuple]: CollectColumns[NamedColumns[S]] with
object CollectColumns:
given collectNamedColumn[N <: Name, T <: DataType]: CollectColumns[NamedColumn[N, T]] with
type CollectedColumns = (N := T) *: EmptyTuple
def underlyingColumns(c: NamedColumn[N, T]) = Seq(c.untyped)

given collectColumnsWithSchema[S <: Tuple]: CollectColumns[ColumnsWithSchema[S]] with
type CollectedColumns = S
def underlyingColumns(c: NamedColumns[S]) = c.underlyingColumns
def underlyingColumns(c: ColumnsWithSchema[S]) = c.underlyingColumns

given collectEmptyTuple[S]: CollectColumns[EmptyTuple] with
type CollectedColumns = EmptyTuple
def underlyingColumns(c: EmptyTuple) = Seq.empty

given collectMultiCons[S <: Tuple, T <: Tuple](using collectTail: CollectColumns[T]): (CollectColumns[NamedColumns[S] *: T] { type CollectedColumns = Tuple.Concat[S, collectTail.CollectedColumns] }) =
new CollectColumns[NamedColumns[S] *: T]:
type CollectedColumns = Tuple.Concat[S, collectTail.CollectedColumns]
def underlyingColumns(c: NamedColumns[S] *: T) = c.head.underlyingColumns ++ collectTail.underlyingColumns(c.tail)
given collectCons[H, T <: Tuple](using collectHead: CollectColumns[H], collectTail: CollectColumns[T]): (CollectColumns[H *: T] { type CollectedColumns = Tuple.Concat[collectHead.CollectedColumns, collectTail.CollectedColumns] }) =
new CollectColumns[H *: T]:
type CollectedColumns = Tuple.Concat[collectHead.CollectedColumns, collectTail.CollectedColumns]
def underlyingColumns(c: H *: T) = collectHead.underlyingColumns(c.head) ++ collectTail.underlyingColumns(c.tail)


// TODO Customize error message for different operations with an explanation
class CannotCollectColumns(typeName: String)
extends Exception(s"Could not find an instance of CollectColumns for ${typeName}")


trait CollectColumnsLowPrio:
given collectSingleCons[S, T <: Tuple](using collectTail: CollectColumns[T]): (CollectColumns[NamedColumns[S] *: T] { type CollectedColumns = S *: collectTail.CollectedColumns}) =
new CollectColumns[NamedColumns[S] *: T]:
type CollectedColumns = S *: collectTail.CollectedColumns
def underlyingColumns(c: NamedColumns[S] *: T) = c.head.underlyingColumns ++ collectTail.underlyingColumns(c.tail)
86 changes: 44 additions & 42 deletions src/main/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,61 +6,35 @@ import scala.quoted.*

import org.apache.spark.sql.{Column => UntypedColumn}
import types.DataType

sealed trait NamedColumns[Schema](val underlyingColumns: Seq[UntypedColumn])

object Columns:
transparent inline def apply(inline columns: NamedColumns[?]*): NamedColumns[?] = ${ applyImpl('columns) }

private def applyImpl(columns: Expr[Seq[NamedColumns[?]]])(using Quotes): Expr[NamedColumns[?]] =
import quotes.reflect.*

val columnValuesWithTypes = columns match
case Varargs(colExprs) =>
colExprs.map { arg =>
arg match
case '{ $value: NamedColumns[schema] } => ('{ ${ value }.underlyingColumns }, Type.of[schema])
}

val columnsValues = columnValuesWithTypes.map(_._1)
val columnsTypes = columnValuesWithTypes.map(_._2)

val schemaTpe = FrameSchema.schemaTypeFromColumnsTypes(columnsTypes)

schemaTpe match
case '[s] =>
'{
val cols = ${ Expr.ofSeq(columnsValues) }.flatten
new NamedColumns[s](cols) {}
}
import MacroHelpers.TupleSubtype

class Column(val untyped: UntypedColumn):
inline def name(using v: ValueOf[Name]): Name = v.value

object Column:
implicit transparent inline def columnToLabeledColumn(inline col: Col[?]): LabeledColumn[?, ?] =
${ columnToLabeledColumnImpl('col) }
implicit transparent inline def columnToNamedColumn(inline col: Col[?]): NamedColumn[?, ?] =
${ columnToNamedColumnImpl('col) }

private def columnToLabeledColumnImpl(col: Expr[Col[?]])(using Quotes): Expr[LabeledColumn[?, ?]] =
private def columnToNamedColumnImpl(col: Expr[Col[?]])(using Quotes): Expr[NamedColumn[?, ?]] =
import quotes.reflect.*
col match
case '{ ($v: StructuralSchemaView).selectDynamic($nm: Name).$asInstanceOf$[Col[tp]] } =>
nm.asTerm.tpe.asType match
case '[Name.Subtype[n]] =>
'{ LabeledColumn[n, tp](${ col }.untyped.as(${ nm })) }
'{ NamedColumn[n, tp](${ col }.untyped.as(${ nm })) }
case '{ $c: Col[tp] } =>
col.asTerm match
case Inlined(_, _, Ident(name)) =>
ConstantType(StringConstant(name)).asType match
case '[Name.Subtype[n]] =>
val alias = Literal(StringConstant(name)).asExprOf[Name]
'{ LabeledColumn[n, tp](${ col }.untyped.as(${ alias })) }
'{ NamedColumn[n, tp](${ col }.untyped.as(${ alias })) }

extension [T <: DataType](col: Col[T])
inline def as[N <: Name](name: N): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(name))
inline def alias[N <: Name](name: N): LabeledColumn[N, T] =
LabeledColumn[N, T](col.untyped.as(name))
inline def as[N <: Name](name: N): NamedColumn[N, T] =
NamedColumn[N, T](col.untyped.as(name))
inline def alias[N <: Name](name: N): NamedColumn[N, T] =
NamedColumn[N, T](col.untyped.as(name))

extension [T1 <: DataType](col1: Col[T1])
inline def +[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Plus[T1, T2]): Col[op.Out] = op(col1, col2)
Expand All @@ -77,16 +51,44 @@ object Column:
inline def &&[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.And[T1, T2]): Col[op.Out] = op(col1, col2)
inline def ||[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Or[T1, T2]): Col[op.Out] = op(col1, col2)


class Col[+T <: DataType](untyped: UntypedColumn) extends Column(untyped)


object Columns:
transparent inline def apply[C <: NamedColumns](columns: C): ColumnsWithSchema[?] = ${ applyImpl('columns) }

private def applyImpl[C : Type](columns: Expr[C])(using Quotes): Expr[ColumnsWithSchema[?]] =
import quotes.reflect.*

Expr.summon[CollectColumns[C]] match
case Some(collectColumns) =>
collectColumns match
case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } =>
Type.of[collectedColumns] match
case '[TupleSubtype[collectedCols]] =>
'{
val cols = ${ cc }.underlyingColumns(${ columns })
ColumnsWithSchema[collectedCols](cols)
}
case None =>
throw CollectColumns.CannotCollectColumns(Type.show[C])


trait NamedColumnOrColumnsLike

type NamedColumns = Repeated[NamedColumnOrColumnsLike]

class NamedColumn[N <: Name, T <: DataType](val untyped: UntypedColumn)
extends NamedColumnOrColumnsLike

class ColumnsWithSchema[Schema <: Tuple](val underlyingColumns: Seq[UntypedColumn]) extends NamedColumnOrColumnsLike


@annotation.showAsInfix
trait :=[L <: LabeledColumn.Label, T <: DataType]
trait :=[L <: ColumnLabel, T <: DataType]

@annotation.showAsInfix
trait /[+Prefix <: Name, +Suffix <: Name]

class LabeledColumn[L <: Name, T <: DataType](untyped: UntypedColumn)
extends NamedColumns[(L := T) *: EmptyTuple](Seq(untyped))

object LabeledColumn:
type Label = Name | (Name / Name)
type ColumnLabel = Name | (Name / Name)
6 changes: 3 additions & 3 deletions src/main/FrameSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ object FrameSchema:
case TupleSubtype[s2] => S1 *: s2
case _ => S1 *: S2 *: EmptyTuple

type NullableLabeledColumn[T] = T match
type NullableLabeledDataType[T] = T match
case label := tpe => label := DataType.Nullable[tpe]

type NullableSchema[T] = T match
case TupleSubtype[s] => Tuple.Map[s, NullableLabeledColumn]
case _ => NullableLabeledColumn[T]
case TupleSubtype[s] => Tuple.Map[s, NullableLabeledDataType]
case _ => NullableLabeledDataType[T]

def reownType[Owner <: Name : Type](schema: Type[?])(using Quotes): Type[?] =
schema match
Expand Down
4 changes: 2 additions & 2 deletions src/main/Grouping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ object GroupBy:

given groupByOps: {} with
extension [View <: SchemaView](groupBy: GroupBy[View])
transparent inline def apply[C <: Repeated[NamedColumns[?]]](groupingColumns: View ?=> C) = ${ applyImpl[View, C]('groupBy, 'groupingColumns) }
transparent inline def apply[C <: NamedColumns](groupingColumns: View ?=> C) = ${ applyImpl[View, C]('groupBy, 'groupingColumns) }

private def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] =
import quotes.reflect.asTerm
Expand Down Expand Up @@ -59,7 +59,7 @@ trait GroupedDataFrame[FullView <: SchemaView]:
object GroupedDataFrame:
given groupedDataFrameOps: {} with
extension [FullView <: SchemaView, GroupKeys <: Tuple, GroupView <: SchemaView](gdf: GroupedDataFrame[FullView]{ type GroupedView = GroupView; type GroupingKeys = GroupKeys })
transparent inline def agg[C <: Repeated[NamedColumns[?]]](columns: (Agg { type View = FullView }, GroupView) ?=> C): StructDataFrame[?] =
transparent inline def agg[C <: NamedColumns](columns: (Agg { type View = FullView }, GroupView) ?=> C): StructDataFrame[?] =
${ aggImpl[FullView, GroupKeys, GroupView, C]('gdf, 'columns) }

private def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type, C : Type](
Expand Down
5 changes: 0 additions & 5 deletions src/main/SchemaView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ trait StructSchemaView extends StructuralSchemaView:
// TODO: What should be the semantics of `*`? How to handle ambiguous columns?
// type AllColumns <: Tuple
// def * : AllColumns

// def selectDynamic(name: String): AliasedSchemaView | LabeledColumn[?, ?] =
// if frameAliases.contains(name)
// then AliasedSchemaView(name)
// else LabeledColumn(col(Name.escape(name)))

override def selectDynamic(name: String): AliasedSchemaView | Column =
if frameAliases.contains(name)
Expand Down
2 changes: 1 addition & 1 deletion src/main/Select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object Select:

given selectOps: {} with
extension [View <: SchemaView](select: Select[View])
transparent inline def apply[C <: Repeated[NamedColumns[?]]](columns: View ?=> C): StructDataFrame[?] =
transparent inline def apply[C <: NamedColumns](columns: View ?=> C): StructDataFrame[?] =
${ applyImpl[View, C]('select, 'columns) }

private def applyImpl[View <: SchemaView : Type, C : Type](using Quotes)(select: Expr[Select[View]], columns: Expr[View ?=> C]) =
Expand Down
2 changes: 1 addition & 1 deletion src/main/WithColumns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object WithColumns:

given withColumnsApply: {} with
extension [Schema <: Tuple, View <: SchemaView](withColumns: WithColumns[Schema, View])
transparent inline def apply[C <: Repeated[NamedColumns[?]]](columns: View ?=> C): StructDataFrame[?] =
transparent inline def apply[C <: NamedColumns](columns: View ?=> C): StructDataFrame[?] =
${ applyImpl[Schema, View, C]('withColumns, 'columns) }

private def applyImpl[Schema <: Tuple : Type, View <: SchemaView : Type, C : Type](
Expand Down
24 changes: 24 additions & 0 deletions src/test/ColumnsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.virtuslab.iskra.test

import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.BeforeAndAfterAll
import org.scalatest.matchers.should.Matchers.shouldEqual

class ColumnsTest extends SparkUnitTest:
import org.virtuslab.iskra.api.*

case class Foo(x1: Int, x2: Int, x3: Int, x4: Int)

val foos = Seq(
Foo(1, 2, 3, 4)
).toDF.asStruct

test("plus") {
val result = foos.select {
val cols1 = Columns($.x1)
val cols2 = Columns($.x2, $.x3)
(cols1, cols2, $.x4)
}.asClass[Foo].collect().toList

result shouldEqual List(Foo(1, 2, 3, 4))
}

0 comments on commit fd2b51e

Please sign in to comment.