Add tree-string color diff #176

Merged: 5 commits, Nov 20, 2024
@@ -3,8 +3,8 @@ package com.github.mrpowers.spark.fast.tests
import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

import scala.reflect.ClassTag

@@ -49,7 +49,7 @@ Expected DataFrame Row Count: '$expectedCount'
truncate: Int = 500,
equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2)
): Unit = {
SchemaComparer.assertSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals)
}
@@ -103,7 +103,7 @@ Expected DataFrame Row Count: '$expectedCount'
ignoreMetadata: Boolean = true
): Unit = {
// first check if the schemas are equal
SchemaComparer.assertSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
assertLargeDatasetContentEquality(actual, expectedDS, equals, orderedComparison)
}
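Note: both dataset-level assertions in this file now delegate their schema check to the new SchemaComparer.assertDatasetSchemaEqual entry point (previously assertSchemaEqual took the two Datasets directly). A minimal caller-side sketch follows; the test data, the SparkSession setup, and the assumption that these hunks sit inside assertSmallDatasetEquality / assertLargeDatasetEquality are illustrative inferences, not taken verbatim from the diff.

// Sketch only: calls the renamed schema check the same way the dataset
// equality assertions do after this change.
import com.github.mrpowers.spark.fast.tests.SchemaComparer
import org.apache.spark.sql.SparkSession

object DatasetSchemaCheckSketch {
  val spark = SparkSession.builder().master("local").getOrCreate()
  import spark.implicits._

  val actualDS   = Seq(("a", 1), ("b", 2)).toDF("letter", "number").as[(String, Int)]
  val expectedDS = Seq(("a", 1), ("b", 2)).toDF("letter", "number").as[(String, Int)]

  // Same argument order as the call sites above; passes because the schemas match.
  SchemaComparer.assertDatasetSchemaEqual(
    actualDS,
    expectedDS,
    ignoreNullable = false,
    ignoreColumnNames = false,
    ignoreColumnOrder = true,
    ignoreMetadata = true
  )
}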
@@ -1,33 +1,134 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff
import com.github.mrpowers.spark.fast.tests.SchemaDiffOutputFormat.SchemaDiffOutputFormat
import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType}
import org.apache.spark.sql.types._

object SchemaComparer {
private val INDENT_GAP = 5
private val DESCRIPTION_GAP = 21
private val TREE_GAP = 6
case class DatasetSchemaMismatch(smth: String) extends Exception(smth)
private def betterSchemaMismatchMessage[T](actualDS: Dataset[T], expectedDS: Dataset[T]): String = {
private def betterSchemaMismatchMessage(actualSchema: StructType, expectedSchema: StructType): String = {
showProductDiff(
("Actual Schema", "Expected Schema"),
actualDS.schema.fields,
expectedDS.schema.fields,
actualSchema.fields,
expectedSchema.fields,
truncate = 200
)
}

def assertSchemaEqual[T](
private def treeSchemaMismatchMessage[T](actualSchema: StructType, expectedSchema: StructType): String = {
def flattenStrucType(s: StructType, indent: Int): (Seq[(Int, StructField)], Int) = s
.foldLeft((Seq.empty[(Int, StructField)], Int.MinValue)) { case ((fieldPair, maxWidth), f) =>
val gap = indent * INDENT_GAP + DESCRIPTION_GAP + f.name.length + f.dataType.typeName.length + f.nullable.toString.length
val pair = fieldPair :+ (indent, f)
val newMaxWidth = scala.math.max(maxWidth, gap)
f.dataType match {
case st: StructType =>
val (flattenPair, width) = flattenStrucType(st, indent + 1)
(pair ++ flattenPair, scala.math.max(newMaxWidth, width))
case _ => (pair, newMaxWidth)
}
}

def depthToIndentStr(depth: Int): String = Range(0, depth).map(_ => "| ").mkString + "|--"
val (treeFieldPair1, tree1MaxWidth) = flattenStrucType(actualSchema, 0)
val (treeFieldPair2, _) = flattenStrucType(expectedSchema, 0)
val (treePair, maxWidth) = treeFieldPair1
.zipAll(treeFieldPair2, (0, null), (0, null))
.foldLeft((Seq.empty[(String, String)], 0)) { case ((acc, maxWidth), ((indent1, field1), (indent2, field2))) =>
val prefix1 = depthToIndentStr(indent1)
val prefix2 = depthToIndentStr(indent2)
val (sprefix1, sprefix2) = if (indent1 != indent2) {
(Red(prefix1), Green(prefix2))
} else {
(DarkGray(prefix1), DarkGray(prefix2))
}

val pair = if (field1 != null && field2 != null) {
val (name1, name2) =
if (field1.name != field2.name)
(Red(field1.name), Green(field2.name))
else
(DarkGray(field1.name), DarkGray(field2.name))

val (dtype1, dtype2) =
if (field1.dataType != field2.dataType)
(Red(field1.dataType.typeName), Green(field2.dataType.typeName))
else
(DarkGray(field1.dataType.typeName), DarkGray(field2.dataType.typeName))

val (nullable1, nullable2) =
if (field1.nullable != field2.nullable)
(Red(field1.nullable.toString), Green(field2.nullable.toString))
else
(DarkGray(field1.nullable.toString), DarkGray(field2.nullable.toString))

val structString1 = s"$sprefix1 $name1 : $dtype1 (nullable = $nullable1)"
val structString2 = s"$sprefix2 $name2 : $dtype2 (nullable = $nullable2)"
(structString1, structString2)
} else {
val structString1 = if (field1 != null) {
val name = Red(field1.name)
val dtype = Red(field1.dataType.typeName)
val nullable = Red(field1.nullable.toString)
s"$sprefix1 $name : $dtype (nullable = $nullable)"
} else ""

val structString2 = if (field2 != null) {
val name = Green(field2.name)
val dtype = Green(field2.dataType.typeName)
val nullable = Green(field2.nullable.toString)
s"$sprefix2 $name : $dtype (nullable = $nullable)"
} else ""
(structString1, structString2)
}
(acc :+ pair, math.max(maxWidth, pair._1.length))
}

val schemaGap = maxWidth + TREE_GAP
val headerGap = tree1MaxWidth + TREE_GAP
treePair
.foldLeft(new StringBuilder("\nActual Schema".padTo(headerGap, ' ') + "Expected Schema\n")) { case (sb, (s1, s2)) =>
val gap = if (s1.isEmpty) headerGap else schemaGap
val s = if (s2.isEmpty) s1 else s1.padTo(gap, ' ')
sb.append(s + s2 + "\n")
}
.toString()
}

def assertDatasetSchemaEqual[T](
actualDS: Dataset[T],
expectedDS: Dataset[T],
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = true,
ignoreMetadata: Boolean = true
ignoreMetadata: Boolean = true,
outputFormat: SchemaDiffOutputFormat = SchemaDiffOutputFormat.Table
): Unit = {
assertSchemaEqual(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata, outputFormat)
}

def assertSchemaEqual(
actualSchema: StructType,
expectedSchema: StructType,
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = true,
ignoreMetadata: Boolean = true,
outputFormat: SchemaDiffOutputFormat = SchemaDiffOutputFormat.Table
): Unit = {
require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.")
if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)) {
throw DatasetSchemaMismatch(
"Diffs\n" + betterSchemaMismatchMessage(actualDS, expectedDS)
)
if (!SchemaComparer.equals(actualSchema, expectedSchema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)) {
val diffString = outputFormat match {
case SchemaDiffOutputFormat.Tree => treeSchemaMismatchMessage(actualSchema, expectedSchema)
case SchemaDiffOutputFormat.Table => betterSchemaMismatchMessage(actualSchema, expectedSchema)
}

throw DatasetSchemaMismatch(s"Diffs\n$diffString")
}
}

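The headline change: assertSchemaEqual now also works on plain StructType values and takes an outputFormat parameter, with SchemaDiffOutputFormat.Tree selecting the new colored tree rendering. A small usage sketch; the two schemas below are invented for illustration and only the API visible in this diff is assumed.

import com.github.mrpowers.spark.fast.tests.{SchemaComparer, SchemaDiffOutputFormat}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object TreeSchemaDiffSketch {
  val actualSchema = StructType(Seq(
    StructField("name", StringType, nullable = true),
    StructField("age", IntegerType, nullable = true)
  ))
  val expectedSchema = StructType(Seq(
    StructField("name", StringType, nullable = true),
    StructField("age", StringType, nullable = false)
  ))

  // Throws DatasetSchemaMismatch. With Tree output the message prints both
  // schemas side by side as "|-- field : type (nullable = ...)" lines,
  // matching parts in DarkGray, mismatches in Red (actual) and Green (expected).
  SchemaComparer.assertSchemaEqual(
    actualSchema,
    expectedSchema,
    outputFormat = SchemaDiffOutputFormat.Tree
  )
}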
@@ -0,0 +1,7 @@
package com.github.mrpowers.spark.fast.tests

object SchemaDiffOutputFormat extends Enumeration {
type SchemaDiffOutputFormat = Value

val Tree, Table = Value
}
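SchemaDiffOutputFormat is a plain scala.Enumeration, so its values can be passed around and matched on exactly as SchemaComparer.assertSchemaEqual does above. A short hypothetical helper, not part of the PR:

import com.github.mrpowers.spark.fast.tests.SchemaDiffOutputFormat

object SchemaDiffOutputFormatSketch {
  // Illustrative only: branch on the format the same way the comparer does.
  def describe(format: SchemaDiffOutputFormat.SchemaDiffOutputFormat): String = format match {
    case SchemaDiffOutputFormat.Tree  => "render the schema diff as a colored tree"
    case SchemaDiffOutputFormat.Table => "render the schema diff as the existing side-by-side table"
  }
}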