diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataTypesTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataTypesTests.cs
new file mode 100644
index 000000000..919f61cb8
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataTypesTests.cs
@@ -0,0 +1,65 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests
+{
+
+ [Collection("Spark E2E Tests")]
+ public class DataTypesTests
+ {
+ private readonly SparkSession _spark;
+
+ public DataTypesTests(SparkFixture fixture)
+ {
+ _spark = fixture.Spark;
+ }
+
+ /// <summary>
+ /// Tests that we can pass a decimal over to Apache Spark and collect it back again, including a
+ /// check for the minimum and maximum decimal values that .NET can represent.
+ /// </summary>
+ [Fact]
+ public void TestDecimalType()
+ {
+ var df = _spark.CreateDataFrame(
+ new List<GenericRow>
+ {
+ new GenericRow(
+ new object[]
+ {
+ decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne,
+ new object[]
+ {
+ decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne
+ }
+ }),
+ },
+ new StructType(
+ new List<StructField>()
+ {
+ new StructField("min", new DecimalType(38, 0)),
+ new StructField("max", new DecimalType(38, 0)),
+ new StructField("zero", new DecimalType(38, 0)),
+ new StructField("minusOne", new DecimalType(38, 0)),
+ new StructField("array", new ArrayType(new DecimalType(38,0)))
+ }));
+
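+ // Collect the single row back from the JVM and verify that each decimal,
+ // including the array column (returned as object[]), survived the round trip.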
+ Row row = df.Collect().First();
+ Assert.Equal(decimal.MinValue, row[0]);
+ Assert.Equal(decimal.MaxValue, row[1]);
+ Assert.Equal(decimal.Zero, row[2]);
+ Assert.Equal(decimal.MinusOne, row[3]);
+ Assert.Equal(
+ new object[] { decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne },
+ row[4]);
+ }
+
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs
index f7bd145e3..535991b36 100644
--- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs
@@ -269,6 +269,9 @@ private object CallJavaMethod(
case 'd':
returnValue = SerDe.ReadDouble(inputStream);
break;
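+ // 'm' marks a decimal; the JVM sends it as its string representation
+ // (BigDecimal.toString), which is parsed back into a .NET decimal here.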
+ case 'm':
+ returnValue = decimal.Parse(SerDe.ReadString(inputStream));
+ break;
case 'b':
returnValue = Convert.ToBoolean(inputStream.ReadByte());
break;
diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs
index 3373bca62..ac9914672 100644
--- a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs
@@ -32,6 +32,7 @@ internal class PayloadHelper
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
private static readonly byte[] s_objectArrTypeId = new[] { (byte)'O' };
+ private static readonly byte[] s_decimalTypeId = new[] { (byte)'m' };
private static readonly ConcurrentDictionary<Type, bool> s_isDictionaryTable =
new ConcurrentDictionary<Type, bool>();
@@ -109,6 +110,10 @@ internal static void ConvertArgsToBytes(
case TypeCode.Double:
SerDe.Write(destination, (double)arg);
break;
+
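+ // Decimals go over the wire as strings (see SerDe.Write(Stream, decimal));
+ // the JVM side parses them into BigDecimal values.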
+ case TypeCode.Decimal:
+ SerDe.Write(destination, (decimal)arg);
+ break;
case TypeCode.Object:
switch (arg)
@@ -321,7 +326,9 @@ internal static byte[] GetTypeId(Type type)
case TypeCode.Boolean:
return s_boolTypeId;
case TypeCode.Double:
return s_doubleTypeId;
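+ // Decimals get their own 'm' wire tag, matching the JVM-side "bigdecimal" type.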
+ case TypeCode.Decimal:
+ return s_decimalTypeId;
case TypeCode.Object:
if (typeof(IJvmObjectReferenceProvider).IsAssignableFrom(type))
{
diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
index c2c742e87..a36a293a0 100644
--- a/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
@@ -322,6 +322,13 @@ public static void Write(Stream s, long value)
public static void Write(Stream s, double value) =>
Write(s, BitConverter.DoubleToInt64Bits(value));
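+ // The string form is parsed into a BigDecimal on the JVM side. Note that
+ // ToString() uses the current culture; an invariant-culture format would be
+ // safer if fractional values ever cross a locale that uses a ',' separator.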
+ /// <summary>
+ /// Writes a decimal to a stream as a string.
+ /// </summary>
+ /// <param name="s">The stream to write</param>
+ /// <param name="value">The decimal to write</param>
+ public static void Write(Stream s, decimal value) => Write(s, value.ToString());
+
/// <summary>
/// Writes a string to a stream.
/// </summary>
diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
index 44cad97c1..31cf97c12 100644
--- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
+++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
@@ -19,6 +19,7 @@ import scala.collection.JavaConverters._
* This implementation of methods is mostly identical to the SerDe implementation in R.
*/
class SerDe(val tracker: JVMObjectTracker) {
+
def readObjectType(dis: DataInputStream): Char = {
dis.readByte().toChar
}
@@ -35,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'g' => new java.lang.Long(readLong(dis))
case 'd' => new java.lang.Double(readDouble(dis))
case 'b' => new java.lang.Boolean(readBoolean(dis))
+ case 'm' => readDecimal(dis)
case 'c' => readString(dis)
case 'e' => readMap(dis)
case 'r' => readBytes(dis)
@@ -59,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) {
in.readInt()
}
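+ // A decimal arrives as its string form; parse it into a scala.math.BigDecimal.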
+ private def readDecimal(in: DataInputStream): BigDecimal = {
+ BigDecimal(readString(in))
+ }
+
private def readLong(in: DataInputStream): Long = {
in.readLong()
}
@@ -110,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) {
(0 until len).map(_ => readInt(in)).toArray
}
+ private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDecimal(in)).toArray
+ }
+
private def readLongArr(in: DataInputStream): Array[Long] = {
val len = readInt(in)
(0 until len).map(_ => readLong(in)).toArray
@@ -156,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
case 'r' => readBytesArr(dis)
+ case 'm' => readDecimalArr(dis)
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
}
}
@@ -206,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case "long" => dos.writeByte('g')
case "integer" => dos.writeByte('i')
case "logical" => dos.writeByte('b')
+ case "bigdecimal" => dos.writeByte('m')
case "date" => dos.writeByte('D')
case "time" => dos.writeByte('t')
case "raw" => dos.writeByte('r')
@@ -238,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) {
case "boolean" | "java.lang.Boolean" =>
writeType(dos, "logical")
writeBoolean(dos, value.asInstanceOf[Boolean])
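+ // BigDecimal values are tagged "bigdecimal" ('m' on the wire) and written as strings.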
+ case "BigDecimal" | "java.math.BigDecimal" =>
+ writeType(dos, "bigdecimal")
+ writeString(dos, value.toString)
case "java.sql.Date" =>
writeType(dos, "date")
writeDate(dos, value.asInstanceOf[Date])
diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
index a3df3788a..31cf97c12 100644
--- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
@@ -36,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'g' => new java.lang.Long(readLong(dis))
case 'd' => new java.lang.Double(readDouble(dis))
case 'b' => new java.lang.Boolean(readBoolean(dis))
+ case 'm' => readDecimal(dis)
case 'c' => readString(dis)
case 'e' => readMap(dis)
case 'r' => readBytes(dis)
@@ -60,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) {
in.readInt()
}
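+ // A decimal arrives as its string form; parse it into a scala.math.BigDecimal.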
+ private def readDecimal(in: DataInputStream): BigDecimal = {
+ BigDecimal(readString(in))
+ }
+
private def readLong(in: DataInputStream): Long = {
in.readLong()
}
@@ -111,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) {
(0 until len).map(_ => readInt(in)).toArray
}
+ private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDecimal(in)).toArray
+ }
+
private def readLongArr(in: DataInputStream): Array[Long] = {
val len = readInt(in)
(0 until len).map(_ => readLong(in)).toArray
@@ -157,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
case 'r' => readBytesArr(dis)
+ case 'm' => readDecimalArr(dis)
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
}
}
@@ -207,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case "long" => dos.writeByte('g')
case "integer" => dos.writeByte('i')
case "logical" => dos.writeByte('b')
+ case "bigdecimal" => dos.writeByte('m')
case "date" => dos.writeByte('D')
case "time" => dos.writeByte('t')
case "raw" => dos.writeByte('r')
@@ -239,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) {
case "boolean" | "java.lang.Boolean" =>
writeType(dos, "logical")
writeBoolean(dos, value.asInstanceOf[Boolean])
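+ // BigDecimal values are tagged "bigdecimal" ('m' on the wire) and written as strings.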
+ case "BigDecimal" | "java.math.BigDecimal" =>
+ writeType(dos, "bigdecimal")
+ writeString(dos, value.toString)
case "java.sql.Date" =>
writeType(dos, "date")
writeDate(dos, value.asInstanceOf[Date])
diff --git a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
index a3df3788a..31cf97c12 100644
--- a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
+++ b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
@@ -36,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'g' => new java.lang.Long(readLong(dis))
case 'd' => new java.lang.Double(readDouble(dis))
case 'b' => new java.lang.Boolean(readBoolean(dis))
+ case 'm' => readDecimal(dis)
case 'c' => readString(dis)
case 'e' => readMap(dis)
case 'r' => readBytes(dis)
@@ -60,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) {
in.readInt()
}
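+ // A decimal arrives as its string form; parse it into a scala.math.BigDecimal.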
+ private def readDecimal(in: DataInputStream): BigDecimal = {
+ BigDecimal(readString(in))
+ }
+
private def readLong(in: DataInputStream): Long = {
in.readLong()
}
@@ -111,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) {
(0 until len).map(_ => readInt(in)).toArray
}
+ private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDecimal(in)).toArray
+ }
+
private def readLongArr(in: DataInputStream): Array[Long] = {
val len = readInt(in)
(0 until len).map(_ => readLong(in)).toArray
@@ -157,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
case 'r' => readBytesArr(dis)
+ case 'm' => readDecimalArr(dis)
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
}
}
@@ -207,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case "long" => dos.writeByte('g')
case "integer" => dos.writeByte('i')
case "logical" => dos.writeByte('b')
+ case "bigdecimal" => dos.writeByte('m')
case "date" => dos.writeByte('D')
case "time" => dos.writeByte('t')
case "raw" => dos.writeByte('r')
@@ -239,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) {
case "boolean" | "java.lang.Boolean" =>
writeType(dos, "logical")
writeBoolean(dos, value.asInstanceOf[Boolean])
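+ // BigDecimal values are tagged "bigdecimal" ('m' on the wire) and written as strings.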
+ case "BigDecimal" | "java.math.BigDecimal" =>
+ writeType(dos, "bigdecimal")
+ writeString(dos, value.toString)
case "java.sql.Date" =>
writeType(dos, "date")
writeDate(dos, value.asInstanceOf[Date])