diff --git a/build.sbt b/build.sbt
index 0c837cb5..f65ef709 100644
--- a/build.sbt
+++ b/build.sbt
@@ -58,11 +58,11 @@ lazy val versions = new {
   val commons_math = "3.5"
   val joda_time = "2.9.4"
   val httpclient = "4.3.2" // Note that newer versions need to be configured
-  val spark = sys.props.getOrElse("spark.version", default = "2.3.0")
+  val spark = sys.props.getOrElse("spark.version", default = "2.4.0")
   val scalatest = "2.2.4"
   val scalacheck = "1.12.6"
   val grizzled_slf4j = "1.3.0"
-  val arrow = "0.8.0"
+  val arrow = "0.10.0"
   val jackson_module = "2.7.2"
 }
diff --git a/src/main/scala/com/twosigma/flint/arrow/ArrowWriter.scala b/src/main/scala/com/twosigma/flint/arrow/ArrowWriter.scala
index 2031b609..e186b12b 100644
--- a/src/main/scala/com/twosigma/flint/arrow/ArrowWriter.scala
+++ b/src/main/scala/com/twosigma/flint/arrow/ArrowWriter.scala
@@ -81,7 +81,7 @@ object ArrowWriter {
     case (ArrayType(_, _), vector: ListVector) =>
       val elementVector = createFieldWriter(vector.getDataVector())
       new ArrayWriter(vector, elementVector)
-    case (StructType(_), vector: NullableMapVector) =>
+    case (StructType(_), vector: StructVector) =>
       val children = (0 until vector.size()).map { ordinal =>
        createFieldWriter(vector.getChildByOrdinal(ordinal))
      }
@@ -334,7 +334,7 @@ private[arrow] class ArrayWriter(
 }
 
 private[arrow] class StructWriter(
-  val valueVector: NullableMapVector,
+  val valueVector: StructVector,
   children: Array[ArrowFieldWriter]
 ) extends ArrowFieldWriter {
diff --git a/src/main/scala/com/twosigma/flint/arrow/BufferHolder.java b/src/main/scala/com/twosigma/flint/arrow/BufferHolder.java
new file mode 100644
index 00000000..eb975f52
--- /dev/null
+++ b/src/main/scala/com/twosigma/flint/arrow/BufferHolder.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.twosigma.flint.arrow;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+
+/**
+ * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and
+ * automatically re-point the unsafe row to it.
+ *
+ * This class can be used to build a one-pass unsafe row writing program, i.e. data will be written
+ * to the data buffer directly and no extra copy is needed. There should be only one instance of
+ * this class per writing program, so that the memory segment/data buffer can be reused. Note that
+ * for each incoming record, we should call `reset` on the BufferHolder instance before writing the
+ * record, so that the data buffer can be reused.
+ *
+ * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update
+ * the size of the result row, after writing a record to the buffer. However, we can skip this step
+ * if the fields of the row are all fixed-length, as the size of the result row is also fixed.
+ */
+public class BufferHolder {
+
+  private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
+
+  public byte[] buffer;
+  public int cursor = Platform.BYTE_ARRAY_OFFSET;
+  private final UnsafeRow row;
+  private final int fixedSize;
+
+  public BufferHolder(UnsafeRow row) {
+    this(row, 64);
+  }
+
+  public BufferHolder(UnsafeRow row, int initialSize) {
+    int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields());
+    if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) {
+      throw new UnsupportedOperationException(
+        "Cannot create BufferHolder for input UnsafeRow because there are " +
+          "too many fields (number of fields: " + row.numFields() + ")");
+    }
+    this.fixedSize = bitsetWidthInBytes + 8 * row.numFields();
+    this.buffer = new byte[fixedSize + initialSize];
+    this.row = row;
+    this.row.pointTo(buffer, buffer.length);
+  }
+
+  /**
+   * Grows the buffer by at least neededSize and points the row to the buffer.
+   */
+  public void grow(int neededSize) {
+    if (neededSize > ARRAY_MAX - totalSize()) {
+      throw new UnsupportedOperationException(
+        "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " +
+          "exceeds size limitation " + ARRAY_MAX);
+    }
+    final int length = totalSize() + neededSize;
+    if (buffer.length < length) {
+      // This will not happen frequently, because the buffer is re-used.
+      int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX;
+      final byte[] tmp = new byte[newLength];
+      Platform.copyMemory(
+        buffer,
+        Platform.BYTE_ARRAY_OFFSET,
+        tmp,
+        Platform.BYTE_ARRAY_OFFSET,
+        totalSize());
+      buffer = tmp;
+      row.pointTo(buffer, buffer.length);
+    }
+  }
+
+  public void reset() {
+    cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
+  }
+
+  public int totalSize() {
+    return cursor - Platform.BYTE_ARRAY_OFFSET;
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/com/twosigma/flint/arrow/UnsafeRowWriter.java b/src/main/scala/com/twosigma/flint/arrow/UnsafeRowWriter.java
new file mode 100644
index 00000000..bc8e3e8b
--- /dev/null
+++ b/src/main/scala/com/twosigma/flint/arrow/UnsafeRowWriter.java
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.twosigma.flint.arrow;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using the `UnsafeRow` format.
+ *
+ * It remembers the offset of the row buffer at which it starts to write, and moves the cursor of
+ * the row buffer while writing. If new data (the input record if this is the outermost writer, or
+ * a nested struct if this is an inner writer) comes in, the starting cursor of the row buffer may
+ * change, so we need to call `UnsafeRowWriter.reset` before writing, to update the
+ * `startingOffset` and clear out the null bits.
+ *
+ * Note that if this is the outermost writer, which means we will always write from the very
+ * beginning of the global row buffer, we don't need to update `startingOffset` and can just call
+ * `zeroOutNullBytes` before writing new data.
+ */
+public class UnsafeRowWriter {
+
+  private final BufferHolder holder;
+  // The offset of the global buffer where we start to write this row.
+  private int startingOffset;
+  private final int nullBitsSize;
+  private final int fixedSize;
+
+  public UnsafeRowWriter(BufferHolder holder, int numFields) {
+    this.holder = holder;
+    this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
+    this.fixedSize = nullBitsSize + 8 * numFields;
+    this.startingOffset = holder.cursor;
+  }
+
+  /**
+   * Resets the `startingOffset` according to the current cursor of the row buffer, and clears out
+   * the null bits. This should be called before we write a new nested struct to the row buffer.
+   */
+  public void reset() {
+    this.startingOffset = holder.cursor;
+
+    // grow the global buffer to make sure it has enough space to write fixed-length data.
+    holder.grow(fixedSize);
+    holder.cursor += fixedSize;
+
+    zeroOutNullBytes();
+  }
+
+  /**
+   * Clears out the null bits. This should be called before we write a new row to the row buffer.
+   */
+  public void zeroOutNullBytes() {
+    for (int i = 0; i < nullBitsSize; i += 8) {
+      Platform.putLong(holder.buffer, startingOffset + i, 0L);
+    }
+  }
+
+  private void zeroOutPaddingBytes(int numBytes) {
+    if ((numBytes & 0x07) > 0) {
+      Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+    }
+  }
+
+  public BufferHolder holder() { return holder; }
+
+  public boolean isNullAt(int ordinal) {
+    return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal);
+  }
+
+  public void setNullAt(int ordinal) {
+    BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+    Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+  }
+
+  public long getFieldOffset(int ordinal) {
+    return startingOffset + nullBitsSize + 8 * ordinal;
+  }
+
+  public void setOffsetAndSize(int ordinal, long size) {
+    setOffsetAndSize(ordinal, holder.cursor, size);
+  }
+
+  public void setOffsetAndSize(int ordinal, long currentCursor, long size) {
+    final long relativeOffset = currentCursor - startingOffset;
+    final long fieldOffset = getFieldOffset(ordinal);
+    final long offsetAndSize = (relativeOffset << 32) | size;
+
+    Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
+  }
+
+  public void write(int ordinal, boolean value) {
+    final long offset = getFieldOffset(ordinal);
+    Platform.putLong(holder.buffer, offset, 0L);
+    Platform.putBoolean(holder.buffer, offset, value);
+  }
+
+  public void write(int ordinal, byte value) {
+    final long offset = getFieldOffset(ordinal);
+    Platform.putLong(holder.buffer, offset, 0L);
+    Platform.putByte(holder.buffer, offset, value);
+  }
+
+  public void write(int ordinal, short value) {
+    final long offset = getFieldOffset(ordinal);
+    Platform.putLong(holder.buffer, offset, 0L);
+    Platform.putShort(holder.buffer, offset, value);
+  }
+
+  public void write(int ordinal, int value) {
+    final long offset = getFieldOffset(ordinal);
+    Platform.putLong(holder.buffer, offset, 0L);
+    Platform.putInt(holder.buffer, offset, value);
+  }
+
+  public void write(int ordinal, long value) {
+    Platform.putLong(holder.buffer, getFieldOffset(ordinal), value);
+  }
+
+  public void write(int ordinal, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN; // normalize all NaN bit patterns to the canonical NaN
+    }
+    final long offset = getFieldOffset(ordinal);
+    Platform.putLong(holder.buffer, offset, 0L);
+    Platform.putFloat(holder.buffer, offset, value);
+  }
+
+  public void write(int ordinal, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN; // normalize all NaN bit patterns to the canonical NaN
+    }
+    Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value);
+  }
+
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      // make sure the Decimal object has the same scale as the DecimalType
+      if (input.changePrecision(precision, scale)) {
+        Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+      } else {
+        setNullAt(ordinal);
+      }
+    } else {
+      // grow the global buffer before writing data.
+      holder.grow(16);
+
+      // zero-out the bytes
+      Platform.putLong(holder.buffer, holder.cursor, 0L);
+      Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+
+      // Make sure the Decimal object has the same scale as the DecimalType.
+      // Note that we may pass in a null Decimal object to set the field to null.
+      if (input == null || !input.changePrecision(precision, scale)) {
+        BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+        // keep the offset for future update
+        setOffsetAndSize(ordinal, 0L);
+      } else {
+        final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+        assert bytes.length <= 16;
+
+        // Write the bytes to the variable length portion.
+        Platform.copyMemory(
+          bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+        setOffsetAndSize(ordinal, bytes.length);
+      }
+
+      // move the cursor forward.
+      holder.cursor += 16;
+    }
+  }
+
+  public void write(int ordinal, UTF8String input) {
+    final int numBytes = input.numBytes();
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    holder.grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    holder.cursor += roundedSize;
+  }
+
+  public void write(int ordinal, byte[] input) {
+    write(ordinal, input, 0, input.length);
+  }
+
+  public void write(int ordinal, byte[] input, int offset, int numBytes) {
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    holder.grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset,
+      holder.buffer, holder.cursor, numBytes);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    holder.cursor += roundedSize;
+  }
+
+  public void write(int ordinal, CalendarInterval input) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // Write the months and microseconds fields of Interval to the variable length portion.
+    Platform.putLong(holder.buffer, holder.cursor, input.months);
+    Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+    setOffsetAndSize(ordinal, 16);
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/sql/TimestampCast.scala b/src/main/scala/org/apache/spark/sql/TimestampCast.scala
index bc335bff..0c2521bf 100644
--- a/src/main/scala/org/apache/spark/sql/TimestampCast.scala
+++ b/src/main/scala/org/apache/spark/sql/TimestampCast.scala
@@ -16,7 +16,8 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.codegen.{ CodegenContext, ExprCode }
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.{ Expression, NullIntolerant, UnaryExpression }
 import org.apache.spark.sql.types.{ DataType, LongType, TimestampType }
 
@@ -61,10 +62,10 @@ trait TimestampCast extends UnaryExpression with NullIntolerant {
 
   /** Copied and modified from org/apache/spark/sql/catalyst/expressions/Cast.scala */
   private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String,
-    resultPrim: String, resultNull: String, resultType: DataType): String = {
-    s"""
+    resultPrim: String, resultNull: String, resultType: DataType): Block = {
+    code"""
       boolean $resultNull = $childNull;
-      ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
+      ${JavaCode.javaType(resultType)} $resultPrim = ${CodeGenerator.defaultValue(resultType)};
       if (!${childNull}) {
        $resultPrim = (long) ${cast(childPrim)};
      }
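
On the TimestampCast hunk: Spark 2.4's codegen represents templates as Block values built with the `code` interpolator rather than plain strings, and `javaType`/`defaultValue` moved from CodegenContext to the JavaCode/CodeGenerator helpers, which is all this change adapts to. For a LongType result, the template should expand to Java roughly like the following (identifier names are codegen-assigned placeholders, and `castExpr` stands in for whatever the subclass's `cast` emits, which this hunk does not show):

boolean resultNull = childNull;
long resultPrim = -1L;            // CodeGenerator.defaultValue(LongType)
if (!childNull) {
  resultPrim = (long) (castExpr);
}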
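
As for the copied BufferHolder/UnsafeRowWriter pair: per their class comments, a writing program creates one holder per writer, calls `reset` before each record, and calls `UnsafeRow.setTotalSize(holder.totalSize())` after writing it. A minimal sketch of that protocol (not part of this diff; the two-field schema and all names below are illustrative):

import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.unsafe.types.UTF8String;
import com.twosigma.flint.arrow.BufferHolder;
import com.twosigma.flint.arrow.UnsafeRowWriter;

class UnsafeRowWriteSketch {
  // Builds an UnsafeRow with one long field and one string field.
  static UnsafeRow toUnsafeRow(long id, String name) {
    UnsafeRow row = new UnsafeRow(2);                // 2 fields
    BufferHolder holder = new BufferHolder(row);     // one holder per writing program, reused per record
    UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);

    holder.reset();                                  // rewind the cursor to just past the fixed-length region
    writer.zeroOutNullBytes();                       // outermost writer: clearing the null bits is enough
    writer.write(0, id);                             // fixed-length field, written in place
    writer.write(1, UTF8String.fromString(name));    // variable-length: grows the buffer, records offset/size
    row.setTotalSize(holder.totalSize());            // needed because the row has a variable-length field
    return row;
  }
}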
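
Finally, `setOffsetAndSize` packs a variable-length field's location into the single 8-byte slot reserved for it in the fixed-length region: the upper 32 bits hold the data's offset relative to `startingOffset`, the lower 32 bits its size in bytes. A sketch of the round trip (variable names are illustrative):

long offsetAndSize = (relativeOffset << 32) | size;  // as written by setOffsetAndSize above
int relOffset = (int) (offsetAndSize >> 32);         // upper 32 bits: offset relative to startingOffset
int numBytes = (int) offsetAndSize;                  // lower 32 bits: size in bytes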