Skip to content

Commit

Permalink
better-spliterator: Better spliterator
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Lavrukov committed May 15, 2024
1 parent edaa8c4 commit e80facb
Show file tree
Hide file tree
Showing 5 changed files with 401 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package tech.ydb.yoj.repository.ydb.spliterator;

import tech.ydb.proto.ValueProtos;

import java.util.List;

/**
 * Converts one row of a query result, given in protobuf form, into a value of type {@code V}.
 *
 * @param <V> type of the converted value
 */
@FunctionalInterface
public interface ResultConverter<V> {
    /**
     * Converts a single row.
     *
     * @param columns column descriptors of the result set the row belongs to
     * @param value   protobuf value holding the row's data
     * @return the converted value
     */
V convert(List<ValueProtos.Column> columns, ValueProtos.Value value);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package tech.ydb.yoj.repository.ydb.spliterator;

import tech.ydb.proto.ValueProtos;
import tech.ydb.table.result.ResultSetReader;
import tech.ydb.yoj.repository.ydb.client.YdbConverter;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Iterator over the rows of a single {@link ResultSetReader}.
 * <p>
 * Every row is repackaged as a protobuf {@link ValueProtos.Value} (one item per column) and handed,
 * together with the column descriptors, to the supplied {@link ResultConverter}.
 * The column descriptors built here carry the column name only.
 */
/* package */ final class ResultSetIterator<V> implements Iterator<V> {
    private final ResultSetReader resultSet;
    private final List<ValueProtos.Column> columns;

    private final ResultConverter<V> converter;

    // Index of the row that the next call to next() will return.
    private int position = 0;

    /**
     * @param converter converts each row into the resulting element
     * @param resultSet result set to iterate over; its row cursor is repositioned by this iterator
     */
    public ResultSetIterator(ResultConverter<V> converter, ResultSetReader resultSet) {
        this.converter = converter;
        this.resultSet = resultSet;

        // Column descriptors are only read when the result set has at least one row.
        this.columns = resultSet.getRowCount() > 0
                ? describeColumns(resultSet)
                : new ArrayList<ValueProtos.Column>();

        this.resultSet.setRowIndex(0);
    }

    @Override
    public boolean hasNext() {
        return resultSet.getRowCount() > position;
    }

    @Override
    public V next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }

        // The cursor is advanced before the row is read, so a failed read still consumes the row.
        int currentRow = position;
        position = currentRow + 1;

        ValueProtos.Value row = readRow(currentRow);
        return converter.convert(columns, row);
    }

    /** Moves the reader to {@code rowIndex} and collects all column values of that row into one protobuf value. */
    private ValueProtos.Value readRow(int rowIndex) {
        resultSet.setRowIndex(rowIndex);

        ValueProtos.Value.Builder rowBuilder = ValueProtos.Value.newBuilder();
        int columnIndex = 0;
        while (columnIndex < columns.size()) {
            rowBuilder.addItems(YdbConverter.convertValueToProto(resultSet.getColumn(columnIndex)));
            columnIndex++;
        }
        return rowBuilder.build();
    }

    /** Builds name-only column descriptors for every column of the result set. */
    private static List<ValueProtos.Column> describeColumns(ResultSetReader resultSet) {
        resultSet.setRowIndex(0);

        List<ValueProtos.Column> described = new ArrayList<>();
        for (int index = 0; index < resultSet.getColumnCount(); index++) {
            ValueProtos.Column column = ValueProtos.Column.newBuilder()
                    .setName(resultSet.getColumnName(index))
                    .build();
            described.add(column);
        }
        return described;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package tech.ydb.yoj.repository.ydb.spliterator;

import tech.ydb.table.result.ResultSetReader;
import tech.ydb.yoj.ExperimentalApi;

import java.time.Duration;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Sequential, non-splittable {@link Spliterator} over result sets delivered through an internal
 * {@link YdbSpliteratorQueue} of capacity 1. Each polled {@link ResultSetReader} is unpacked row by row
 * via {@link ResultSetIterator} and converted with the supplied {@link ResultConverter}.
 * <p>
 * Use {@link #createStream()} to obtain a stream: it registers {@link #close()} as the stream's
 * close handler.
 *
 * @param <V> element type produced by the converter
 */
@ExperimentalApi(issue = "https://github.com/ydb-platform/yoj-project/issues/42")
public final class YdbSpliterator<V> implements Spliterator<V> {
    private static final Duration DEFAULT_STREAM_WORK_TIMEOUT = Duration.ofMinutes(5);

    private final ResultConverter<V> converter;

    private final int flags;
    private final YdbSpliteratorQueue<ResultSetReader> queue;

    // Iterator over the result set currently being consumed; null until the first result set is polled.
    private ResultSetIterator<V> resultIterator;

    private volatile boolean closed = false;

    /**
     * @param converter converts each row into a stream element
     * @param isOrdered whether the spliterator should report the {@link #ORDERED} characteristic
     */
    public YdbSpliterator(ResultConverter<V> converter, boolean isOrdered) {
        this(converter, isOrdered, DEFAULT_STREAM_WORK_TIMEOUT);
    }

    private YdbSpliterator(ResultConverter<V> converter, boolean isOrdered, Duration streamWorkTimeout) {
        this.converter = converter;
        this.flags = (isOrdered ? ORDERED : 0) | NONNULL;
        this.queue = new YdbSpliteratorQueue<>(1, streamWorkTimeout);
    }

    /**
     * The correct way to create a stream backed by this spliterator.
     * The {@code onClose} handler is important: it closes the spliterator to avoid leaking the supplier thread.
     *
     * @return a sequential stream which closes this spliterator when the stream itself is closed
     */
    public Stream<V> createStream() {
        return StreamSupport.stream(this, false).onClose(this::close);
    }

    /**
     * Passes the next row, if any, to {@code action}.
     *
     * @return {@code false} if the spliterator is closed or the queue reported the end of data
     */
    @Override
    public boolean tryAdvance(Consumer<? super V> action) {
        if (closed) {
            return false;
        }

        // A loop, not a single check: a polled result set may contain zero rows, in which case
        // calling next() on its iterator would throw NoSuchElementException. Keep polling until
        // a result set with at least one row arrives or the queue signals the end of data.
        while (resultIterator == null || !resultIterator.hasNext()) {
            ResultSetReader resultSet = queue.poll();
            if (resultSet == null) {
                closed = true;
                return false;
            }
            resultIterator = new ResultSetIterator<>(converter, resultSet);
        }

        V value = resultIterator.next();

        action.accept(value);

        return true;
    }

    /** Marks this spliterator as closed and closes the underlying queue. */
    public void close() {
        closed = true;
        queue.close();
    }

    /** Splitting is not supported; always returns {@code null}. */
    @Override
    public Spliterator<V> trySplit() {
        return null;
    }

    /** The size is unknown; always returns {@link Long#MAX_VALUE}. */
    @Override
    public long estimateSize() {
        return Long.MAX_VALUE;
    }

    /** The size is unknown; always returns {@code -1}. */
    @Override
    public long getExactSizeIfKnown() {
        return -1;
    }

    @Override
    public int characteristics() {
        return flags;
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package tech.ydb.yoj.repository.ydb.spliterator;

import com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.ydb.yoj.ExperimentalApi;
import tech.ydb.yoj.repository.db.exception.DeadlineExceededException;
import tech.ydb.yoj.repository.db.exception.QueryInterruptedException;

import java.time.Duration;
import java.util.ArrayDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Bounded hand-off queue between one supplier (calling {@link #onNext} and {@link #supplierDone})
 * and a consumer (calling {@link #poll} and {@link #close}). All state is guarded by a single lock.
 * <p>
 * Every blocking wait is limited by a deadline computed once, at construction time,
 * from the {@code streamWorkTimeout}.
 *
 * @param <V> element type; elements must be non-null because they are stored in an {@link ArrayDeque}
 */
@ExperimentalApi(issue = "https://github.com/ydb-platform/yoj-project/issues/42")
/* package */ final class YdbSpliteratorQueue<V> {
    private static final Logger log = LoggerFactory.getLogger(YdbSpliteratorQueue.class);

    // Sentinel meaning "supplierDone() has not been called yet"; compared by identity.
    private static final SupplierStatus UNDONE_SUPPLIER_STATUS = () -> false;

    private final int maxQueueSize;
    private final ArrayDeque<V> queue;
    private final long streamWorkDeadlineNanos;

    private final Lock lock = new ReentrantLock();
    private final Condition newElement = lock.newCondition();
    private final Condition queueIsNotFull = lock.newCondition();

    // Must start as the sentinel itself (not another lambda) so the check in onNext() is meaningful.
    private SupplierStatus supplierStatus = UNDONE_SUPPLIER_STATUS;
    private boolean closed = false;

    /**
     * @param maxQueueSize      maximum number of buffered elements; must be positive
     * @param streamWorkTimeout time budget, counted from now, shared by all blocking waits of this queue
     * @throws IllegalArgumentException if {@code maxQueueSize} is not positive
     */
    public YdbSpliteratorQueue(int maxQueueSize, Duration streamWorkTimeout) {
        Preconditions.checkArgument(maxQueueSize > 0, "maxQueueSize must be greater than 0");
        this.maxQueueSize = maxQueueSize;
        this.queue = new ArrayDeque<>(maxQueueSize);
        this.streamWorkDeadlineNanos = System.nanoTime() + saturatedToNanos(streamWorkTimeout);
    }

    /**
     * (supplier thread) Offers an element and, if this fills the queue, blocks until the consumer
     * has made room, the queue is closed, or the deadline passes.
     *
     * @return {@code true} if the supplier may continue; {@code false} if the queue is closed
     * @throws IllegalStateException          if called after {@link #supplierDone}
     * @throws OfferDeadlineExceededException if the deadline passed while waiting for free space
     * @throws QueryInterruptedException      if the thread was interrupted while waiting
     */
    public boolean onNext(V value) {
        lock.lock();
        try {
            // Checked under the lock: supplierStatus is only ever written while holding it.
            Preconditions.checkState(supplierStatus == UNDONE_SUPPLIER_STATUS,
                    "can't call onNext after supplierDone"
            );

            if (closed) {
                return false;
            }

            // Only one supplier is possible, queue can't be full in this situation
            queue.add(value);

            newElement.signal();

            // Loop guards against spurious wakeups: returning early while the queue is still full
            // would let the next onNext() overrun maxQueueSize.
            while (queue.size() >= maxQueueSize) {
                try {
                    if (!queueIsNotFull.await(calculateTimeout(), TimeUnit.NANOSECONDS)) {
                        throw new OfferDeadlineExceededException();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new QueryInterruptedException("Supplier thread interrupted", e);
                }

                if (closed) {
                    return false;
                }
            }
        } finally {
            lock.unlock();
        }

        return true;
    }

    /**
     * (supplier thread) Tells the stream that the data is over.
     * Ignored if the queue is already closed.
     *
     * @param status consulted by {@link #poll} once the queue is empty
     */
    public void supplierDone(SupplierStatus status) {
        lock.lock();
        try {
            if (closed) {
                return;
            }

            supplierStatus = status;

            newElement.signal();
        } finally {
            lock.unlock();
        }
    }

    public boolean isClosed() {
        lock.lock();
        try {
            return closed;
        } finally {
            lock.unlock();
        }
    }

    /**
     * (consumer thread) Takes the oldest buffered element, waiting for one if necessary.
     * Elements buffered before {@link #supplierDone} are still delivered; the supplier status
     * is consulted only when the queue is empty.
     *
     * @return the next element, or {@code null} if the queue is closed or the supplier is done and the queue is empty
     * @throws DeadlineExceededException if the deadline passed while waiting for an element
     * @throws QueryInterruptedException if the thread was interrupted while waiting
     */
    public V poll() {
        lock.lock();
        try {
            // Loop guards against spurious wakeups and against being signalled without a new element:
            // the state is re-examined after every wait instead of blindly popping.
            while (!closed) {
                if (!queue.isEmpty()) {
                    V value = queue.pop();

                    queueIsNotFull.signal();

                    return value;
                }

                if (supplierStatus.isDone()) {
                    return null;
                }

                try {
                    if (!newElement.await(calculateTimeout(), TimeUnit.NANOSECONDS)) {
                        // NOTE(review): this message describes the supplier side, although here it is
                        // the consumer that timed out waiting for an element; text kept as is.
                        log.warn("Supplier thread was closed because consumer didn't poll an element of stream on timeout");
                        throw new DeadlineExceededException("Stream deadline exceeded on poll");
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new QueryInterruptedException("Consumer thread interrupted", e);
                }
            }

            return null;
        } finally {
            lock.unlock();
        }
    }

    /** Closes the queue and wakes up any thread blocked in {@link #onNext} or {@link #poll}. Idempotent. */
    public void close() {
        lock.lock();
        try {
            if (closed) {
                return;
            }

            closed = true;

            queueIsNotFull.signal();
            newElement.signalAll();
        } finally {
            lock.unlock();
        }
    }

    /** Nanoseconds left until the deadline; zero or negative once the deadline has passed. */
    private long calculateTimeout() {
        return streamWorkDeadlineNanos - System.nanoTime();
    }

    /** Thrown by {@link #onNext} when the deadline passes while waiting for free space in the queue. */
    public static final class OfferDeadlineExceededException extends RuntimeException {
    }

    // copy-paste from com.google.common.util.concurrent.Uninterruptibles
    private static long saturatedToNanos(Duration duration) {
        try {
            return duration.toNanos();
        } catch (ArithmeticException ignore) {
            // The duration does not fit into a long number of nanoseconds: clamp to the nearest bound.
            return duration.isNegative() ? Long.MIN_VALUE : Long.MAX_VALUE;
        }
    }

    /** Completion status reported by the supplier via {@link #supplierDone}. */
    @FunctionalInterface
    public interface SupplierStatus {
        boolean isDone();
    }
}
Loading

0 comments on commit e80facb

Please sign in to comment.