Skip to content

Commit

Permalink
Fixed median operator
Browse files Browse the repository at this point in the history
  • Loading branch information
vpinna80 committed Jul 26, 2024
1 parent 5a207c7 commit 0d100ff
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
id_2,me_1,me_2
XX,20,17
YY,5,4
XX,2,14
YY,5,3
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ public static Stream<Arguments> test()
Arguments.of(AVG, SAMPLE16, 3.025),
Arguments.of(AVG, SAMPLE17, 14.18),
Arguments.of(MEDIAN, SAMPLE5, 14L),
Arguments.of(MEDIAN, SAMPLE6, 24L),
Arguments.of(MEDIAN, SAMPLE16, 3.3),
Arguments.of(MEDIAN, SAMPLE6, 23L),
Arguments.of(MEDIAN, SAMPLE16, 2.2),
Arguments.of(MEDIAN, SAMPLE17, 14.4),
Arguments.of(MIN, SAMPLE5, 11L),
Arguments.of(MIN, SAMPLE6, 21L),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
import static it.bancaditalia.oss.vtl.impl.types.domain.Domains.NUMBERDS;
import static it.bancaditalia.oss.vtl.util.SerCollectors.collectingAndThen;
import static it.bancaditalia.oss.vtl.util.SerCollectors.counting;
import static it.bancaditalia.oss.vtl.util.SerCollectors.filtering;
import static it.bancaditalia.oss.vtl.util.SerCollectors.mapping;
import static it.bancaditalia.oss.vtl.util.SerCollectors.maxBy;
import static it.bancaditalia.oss.vtl.util.SerCollectors.minBy;
import static it.bancaditalia.oss.vtl.util.SerPredicate.not;
import static java.util.stream.Collector.Characteristics.CONCURRENT;
import static java.util.stream.Collector.Characteristics.UNORDERED;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.EnumSet;

import it.bancaditalia.oss.vtl.impl.types.data.BigDecimalValue;
Expand All @@ -46,21 +47,13 @@
import it.bancaditalia.oss.vtl.util.SerDoubleSumAvgCount;
import it.bancaditalia.oss.vtl.util.SerFunction;
import it.bancaditalia.oss.vtl.util.SerSupplier;
import it.bancaditalia.oss.vtl.util.Utils;

public enum AggregateOperator
{
COUNT(() -> collectingAndThen(counting(), IntegerValue::of)),
SUM(() -> getSummingCollector()),
AVG(() -> getAveragingCollector()),
MEDIAN(() -> collectingAndThen(SerCollectors.toSet(), s -> {
ScalarValue<?, ?, ?, ?>[] array = s.toArray(new ScalarValue<?, ?, ?, ?>[s.size()]);
if (Utils.SEQUENTIAL)
Arrays.sort(array);
else
Arrays.parallelSort(array);
return array[array.length / 2];
})),
MEDIAN(() -> collectingAndThen(filtering(not(NullValue.class::isInstance), new MedianCollector(getSVClass())), opt -> opt.orElse(NullValue.instance(NULLDS)))),
MIN(() -> collectingAndThen(minBy(getSVClass(), ScalarValue::compareTo), opt -> opt.orElse(NullValue.instance(NULLDS)))),
MAX(() -> collectingAndThen(maxBy(getSVClass(), ScalarValue::compareTo), opt -> opt.orElse(NullValue.instance(NULLDS)))),
VAR_POP(() -> collectingAndThen(SerCollectors.mapping(v -> (Number) v.get(), varianceCollector(acu -> acu[2] / (acu[0] + 1))), NumberValueImpl::createNumberValue)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@
import static it.bancaditalia.oss.vtl.impl.types.domain.Domains.NUMBERDS;
import static it.bancaditalia.oss.vtl.impl.types.operators.AggregateOperator.getAveragingCollector;
import static it.bancaditalia.oss.vtl.impl.types.operators.AggregateOperator.getSummingCollector;
import static it.bancaditalia.oss.vtl.impl.types.operators.MedianCollector.medianCollector;
import static it.bancaditalia.oss.vtl.util.SerCollectors.collectingAndThen;
import static it.bancaditalia.oss.vtl.util.SerCollectors.counting;
import static it.bancaditalia.oss.vtl.util.SerCollectors.filtering;
import static it.bancaditalia.oss.vtl.util.SerCollectors.firstValue;
import static it.bancaditalia.oss.vtl.util.SerCollectors.lastValue;
import static it.bancaditalia.oss.vtl.util.SerCollectors.mapping;
import static it.bancaditalia.oss.vtl.util.SerCollectors.maxBy;
import static it.bancaditalia.oss.vtl.util.SerCollectors.minBy;
import static it.bancaditalia.oss.vtl.util.SerPredicate.not;
import static java.util.stream.Collector.Characteristics.UNORDERED;

import java.util.Arrays;
import java.util.EnumSet;

import it.bancaditalia.oss.vtl.impl.types.data.IntegerValue;
Expand All @@ -42,26 +44,17 @@
import it.bancaditalia.oss.vtl.model.domain.ValueDomain;
import it.bancaditalia.oss.vtl.model.domain.ValueDomainSubset;
import it.bancaditalia.oss.vtl.util.SerCollector;
import it.bancaditalia.oss.vtl.util.SerCollectors;
import it.bancaditalia.oss.vtl.util.SerFunction;
import it.bancaditalia.oss.vtl.util.Utils;

public enum AnalyticOperator
{
COUNT(domain -> collectingAndThen(counting(), IntegerValue::of)),
SUM(domain -> getSummingCollector()),
AVG(domain -> getAveragingCollector()),
MEDIAN(domain -> collectingAndThen(SerCollectors.toSet(), s -> {
ScalarValue<?, ?, ?, ?>[] array = s.toArray(new ScalarValue<?, ?, ?, ?>[s.size()]);
if (Utils.SEQUENTIAL)
Arrays.sort(array);
else
Arrays.parallelSort(array);
return array[array.length / 2];
})),
MEDIAN(domain -> collectingAndThen(filtering(not(NullValue.class::isInstance), medianCollector(domain.getValueClass())), opt -> opt.orElse(NullValue.unqualifiedInstance(domain)))),
MIN(domain -> collectingAndThen(minBy(domain.getValueClass(), ScalarValue::compareTo), opt -> opt.orElse(NullValue.unqualifiedInstance(domain)))),
MAX(domain -> collectingAndThen(maxBy(domain.getValueClass(), ScalarValue::compareTo), opt -> opt.orElse(NullValue.unqualifiedInstance(domain)))),
VAR_POP(domain -> collectingAndThen(SerCollectors.mapping(v -> (Number) v.get(), varianceCollector(acu -> acu[2] / (acu[0] + 1))), NumberValueImpl::createNumberValue)),
VAR_POP(domain -> collectingAndThen(mapping(v -> (Number) v.get(), varianceCollector(acu -> acu[2] / (acu[0] + 1))), NumberValueImpl::createNumberValue)),
VAR_SAMP(domain -> collectingAndThen(mapping(v -> (Number) v.get(), varianceCollector(acu -> acu[2] / acu[0])), NumberValueImpl::createNumberValue)),
STDDEV_POP(domain -> collectingAndThen(VAR_POP.getReducer(NUMBERDS), dv -> createNumberValue(Math.sqrt((Double) dv.get())))),
STDDEV_SAMP(domain -> collectingAndThen(VAR_SAMP.getReducer(domain), dv -> createNumberValue(Math.sqrt((Double) dv.get())))),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright © 2020 Banca D'Italia
*
* Licensed under the EUPL, Version 1.2 (the "License");
* You may not use this work except in compliance with the
* License.
* You may obtain a copy of the License at:
*
* https://joinup.ec.europa.eu/sites/default/files/custom-page/attachment/2020-03/EUPL-1.2%20EN.txt
*
* Unless required by applicable law or agreed to in
* writing, software distributed under the License is
* distributed on an "AS IS" basis,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied.
*
* See the License for the specific language governing
* permissions and limitations under the License.
*/
package it.bancaditalia.oss.vtl.impl.types.operators;

import static java.util.Collections.reverseOrder;

import java.io.Serializable;
import java.util.EnumSet;
import java.util.Optional;
import java.util.PriorityQueue;

import it.bancaditalia.oss.vtl.model.data.ScalarValue;
import it.bancaditalia.oss.vtl.util.SerCollector;

public class MedianCollector extends SerCollector<ScalarValue<?, ?, ?, ?>, MedianCollector.MedianAcc, Optional<ScalarValue<?, ?, ?, ?>>>
{
private static final long serialVersionUID = 1L;

public static class MedianAcc implements Serializable
{
private static final long serialVersionUID = 1L;

private final Class<?> repr;
private PriorityQueue<ScalarValue<?, ?, ?, ?>> left = new PriorityQueue<>(reverseOrder());
private PriorityQueue<ScalarValue<?, ?, ?, ?>> right = new PriorityQueue<>();

public MedianAcc(Class<?> repr)
{
this.repr = repr;
}

public void accumulate(ScalarValue<?, ?, ?, ?> value)
{
if (left.isEmpty() || value.compareTo(left.peek()) <= 0)
{
left.add(value);
if (left.size() > right.size() + 1)
right.add(left.poll());
}
else
{
right.add(value);
if (right.size() > left.size())
left.add(right.poll());
}
}

public MedianAcc merge(MedianAcc other)
{
left.addAll(other.left);
right.addAll(other.right);

while (left.size() > right.size() + 1)
right.add(left.poll());
while (right.size() > left.size())
left.add(right.poll());

return this;
}

public Optional<ScalarValue<?, ?, ?, ?>> finish()
{
while (left.size() > right.size() + 1)
right.add(left.poll());
while (right.size() > left.size())
left.add(right.poll());

ScalarValue<?, ?, ?, ?> l = left.peek();
ScalarValue<?, ?, ?, ?> r = right.peek();

if (l != null && r != null)
return Optional.of(l.compareTo(r) <= 0 ? l : r);
else
return Optional.ofNullable(l);
}

public Class<?> getRepr()
{
return repr;
}
}

public MedianCollector(Class<?> repr)
{
super(() -> new MedianAcc(repr), MedianAcc::accumulate, MedianAcc::merge, MedianAcc::finish, EnumSet.noneOf(Characteristics.class));
}

public static SerCollector<ScalarValue<?, ?, ?, ?>, ?, Optional<ScalarValue<?, ?, ?, ?>>> medianCollector(Class<?> repr)
{
return new MedianCollector(repr);
}
}

0 comments on commit 0d100ff

Please sign in to comment.