From 0d100ff64b6bb71336101ac7bfdf125666f156a6 Mon Sep 17 00:00:00 2001 From: Valentino Pinna Date: Fri, 26 Jul 2024 15:41:09 +0200 Subject: [PATCH] Fixed median operator --- .../median/ex_1.csv | 4 +- .../AggregateTransformationTest.java | 4 +- .../types/operators/AggregateOperator.java | 13 +-- .../types/operators/AnalyticOperator.java | 17 +-- .../impl/types/operators/MedianCollector.java | 109 ++++++++++++++++++ 5 files changed, 121 insertions(+), 26 deletions(-) create mode 100644 vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/MedianCollector.java diff --git a/vtl-bundles/vtl-coverage/src/test/resources/it/bancaditalia/oss/vtl/coverage/tests/10 Aggregate and Analytic operators/median/ex_1.csv b/vtl-bundles/vtl-coverage/src/test/resources/it/bancaditalia/oss/vtl/coverage/tests/10 Aggregate and Analytic operators/median/ex_1.csv index 9ce2322b3..250d9b01c 100644 --- a/vtl-bundles/vtl-coverage/src/test/resources/it/bancaditalia/oss/vtl/coverage/tests/10 Aggregate and Analytic operators/median/ex_1.csv +++ b/vtl-bundles/vtl-coverage/src/test/resources/it/bancaditalia/oss/vtl/coverage/tests/10 Aggregate and Analytic operators/median/ex_1.csv @@ -1,3 +1,3 @@ id_2,me_1,me_2 -XX,20,17 -YY,5,4 +XX,2,14 +YY,5,3 diff --git a/vtl-transform/src/test/java/it/bancaditalia/oss/vtl/impl/transform/aggregation/AggregateTransformationTest.java b/vtl-transform/src/test/java/it/bancaditalia/oss/vtl/impl/transform/aggregation/AggregateTransformationTest.java index fcdc0a4e9..2d0e32784 100644 --- a/vtl-transform/src/test/java/it/bancaditalia/oss/vtl/impl/transform/aggregation/AggregateTransformationTest.java +++ b/vtl-transform/src/test/java/it/bancaditalia/oss/vtl/impl/transform/aggregation/AggregateTransformationTest.java @@ -64,8 +64,8 @@ public static Stream test() Arguments.of(AVG, SAMPLE16, 3.025), Arguments.of(AVG, SAMPLE17, 14.18), Arguments.of(MEDIAN, SAMPLE5, 14L), - Arguments.of(MEDIAN, SAMPLE6, 24L), - Arguments.of(MEDIAN, SAMPLE16, 3.3), + 
Arguments.of(MEDIAN, SAMPLE6, 23L), + Arguments.of(MEDIAN, SAMPLE16, 2.2), Arguments.of(MEDIAN, SAMPLE17, 14.4), Arguments.of(MIN, SAMPLE5, 11L), Arguments.of(MIN, SAMPLE6, 21L), diff --git a/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/AggregateOperator.java b/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/AggregateOperator.java index 5ed92c916..2859de093 100644 --- a/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/AggregateOperator.java +++ b/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/AggregateOperator.java @@ -25,14 +25,15 @@ import static it.bancaditalia.oss.vtl.impl.types.domain.Domains.NUMBERDS; import static it.bancaditalia.oss.vtl.util.SerCollectors.collectingAndThen; import static it.bancaditalia.oss.vtl.util.SerCollectors.counting; +import static it.bancaditalia.oss.vtl.util.SerCollectors.filtering; import static it.bancaditalia.oss.vtl.util.SerCollectors.mapping; import static it.bancaditalia.oss.vtl.util.SerCollectors.maxBy; import static it.bancaditalia.oss.vtl.util.SerCollectors.minBy; +import static it.bancaditalia.oss.vtl.util.SerPredicate.not; import static java.util.stream.Collector.Characteristics.CONCURRENT; import static java.util.stream.Collector.Characteristics.UNORDERED; import java.math.BigDecimal; -import java.util.Arrays; import java.util.EnumSet; import it.bancaditalia.oss.vtl.impl.types.data.BigDecimalValue; @@ -46,21 +47,13 @@ import it.bancaditalia.oss.vtl.util.SerDoubleSumAvgCount; import it.bancaditalia.oss.vtl.util.SerFunction; import it.bancaditalia.oss.vtl.util.SerSupplier; -import it.bancaditalia.oss.vtl.util.Utils; public enum AggregateOperator { COUNT(() -> collectingAndThen(counting(), IntegerValue::of)), SUM(() -> getSummingCollector()), AVG(() -> getAveragingCollector()), - MEDIAN(() -> collectingAndThen(SerCollectors.toSet(), s -> { - ScalarValue[] array = s.toArray(new ScalarValue[s.size()]); - if (Utils.SEQUENTIAL) - 
Arrays.sort(array); - else - Arrays.parallelSort(array); - return array[array.length / 2]; - })), + MEDIAN(() -> collectingAndThen(filtering(not(NullValue.class::isInstance), new MedianCollector(getSVClass())), opt -> opt.orElse(NullValue.instance(NULLDS)))), MIN(() -> collectingAndThen(minBy(getSVClass(), ScalarValue::compareTo), opt -> opt.orElse(NullValue.instance(NULLDS)))), MAX(() -> collectingAndThen(maxBy(getSVClass(), ScalarValue::compareTo), opt -> opt.orElse(NullValue.instance(NULLDS)))), VAR_POP(() -> collectingAndThen(SerCollectors.mapping(v -> (Number) v.get(), varianceCollector(acu -> acu[2] / (acu[0] + 1))), NumberValueImpl::createNumberValue)), diff --git a/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/AnalyticOperator.java b/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/AnalyticOperator.java index 4de46fc9e..e53eeb2b2 100644 --- a/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/AnalyticOperator.java +++ b/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/AnalyticOperator.java @@ -23,16 +23,18 @@ import static it.bancaditalia.oss.vtl.impl.types.domain.Domains.NUMBERDS; import static it.bancaditalia.oss.vtl.impl.types.operators.AggregateOperator.getAveragingCollector; import static it.bancaditalia.oss.vtl.impl.types.operators.AggregateOperator.getSummingCollector; +import static it.bancaditalia.oss.vtl.impl.types.operators.MedianCollector.medianCollector; import static it.bancaditalia.oss.vtl.util.SerCollectors.collectingAndThen; import static it.bancaditalia.oss.vtl.util.SerCollectors.counting; +import static it.bancaditalia.oss.vtl.util.SerCollectors.filtering; import static it.bancaditalia.oss.vtl.util.SerCollectors.firstValue; import static it.bancaditalia.oss.vtl.util.SerCollectors.lastValue; import static it.bancaditalia.oss.vtl.util.SerCollectors.mapping; import static it.bancaditalia.oss.vtl.util.SerCollectors.maxBy; import static 
it.bancaditalia.oss.vtl.util.SerCollectors.minBy; +import static it.bancaditalia.oss.vtl.util.SerPredicate.not; import static java.util.stream.Collector.Characteristics.UNORDERED; -import java.util.Arrays; import java.util.EnumSet; import it.bancaditalia.oss.vtl.impl.types.data.IntegerValue; @@ -42,26 +44,17 @@ import it.bancaditalia.oss.vtl.model.domain.ValueDomain; import it.bancaditalia.oss.vtl.model.domain.ValueDomainSubset; import it.bancaditalia.oss.vtl.util.SerCollector; -import it.bancaditalia.oss.vtl.util.SerCollectors; import it.bancaditalia.oss.vtl.util.SerFunction; -import it.bancaditalia.oss.vtl.util.Utils; public enum AnalyticOperator { COUNT(domain -> collectingAndThen(counting(), IntegerValue::of)), SUM(domain -> getSummingCollector()), AVG(domain -> getAveragingCollector()), - MEDIAN(domain -> collectingAndThen(SerCollectors.toSet(), s -> { - ScalarValue[] array = s.toArray(new ScalarValue[s.size()]); - if (Utils.SEQUENTIAL) - Arrays.sort(array); - else - Arrays.parallelSort(array); - return array[array.length / 2]; - })), + MEDIAN(domain -> collectingAndThen(filtering(not(NullValue.class::isInstance), medianCollector(domain.getValueClass())), opt -> opt.orElse(NullValue.unqualifiedInstance(domain)))), MIN(domain -> collectingAndThen(minBy(domain.getValueClass(), ScalarValue::compareTo), opt -> opt.orElse(NullValue.unqualifiedInstance(domain)))), MAX(domain -> collectingAndThen(maxBy(domain.getValueClass(), ScalarValue::compareTo), opt -> opt.orElse(NullValue.unqualifiedInstance(domain)))), - VAR_POP(domain -> collectingAndThen(SerCollectors.mapping(v -> (Number) v.get(), varianceCollector(acu -> acu[2] / (acu[0] + 1))), NumberValueImpl::createNumberValue)), + VAR_POP(domain -> collectingAndThen(mapping(v -> (Number) v.get(), varianceCollector(acu -> acu[2] / (acu[0] + 1))), NumberValueImpl::createNumberValue)), VAR_SAMP(domain -> collectingAndThen(mapping(v -> (Number) v.get(), varianceCollector(acu -> acu[2] / acu[0])), 
NumberValueImpl::createNumberValue)), STDDEV_POP(domain -> collectingAndThen(VAR_POP.getReducer(NUMBERDS), dv -> createNumberValue(Math.sqrt((Double) dv.get())))), STDDEV_SAMP(domain -> collectingAndThen(VAR_SAMP.getReducer(domain), dv -> createNumberValue(Math.sqrt((Double) dv.get())))), diff --git a/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/MedianCollector.java b/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/MedianCollector.java new file mode 100644 index 000000000..9eb11d52c --- /dev/null +++ b/vtl-types/src/main/java/it/bancaditalia/oss/vtl/impl/types/operators/MedianCollector.java @@ -0,0 +1,109 @@ +/* + * Copyright © 2020 Banca D'Italia + * + * Licensed under the EUPL, Version 1.2 (the "License"); + * You may not use this work except in compliance with the + * License. + * You may obtain a copy of the License at: + * + * https://joinup.ec.europa.eu/sites/default/files/custom-page/attachment/2020-03/EUPL-1.2%20EN.txt + * + * Unless required by applicable law or agreed to in + * writing, software distributed under the License is + * distributed on an "AS IS" basis, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. + * + * See the License for the specific language governing + * permissions and limitations under the License. 
+ */
+package it.bancaditalia.oss.vtl.impl.types.operators;
+
+import static java.util.Collections.reverseOrder;
+
+import java.io.Serializable;
+import java.util.EnumSet;
+import java.util.Optional;
+import java.util.PriorityQueue;
+
+import it.bancaditalia.oss.vtl.model.data.ScalarValue;
+import it.bancaditalia.oss.vtl.util.SerCollector;
+
+/** Collector returning the lower median of the collected values (empty when none), computed with a max-heap/min-heap pair. */
+public class MedianCollector extends SerCollector<ScalarValue<?, ?, ?, ?>, MedianCollector.MedianAcc, Optional<ScalarValue<?, ?, ?, ?>>>
+{
+	private static final long serialVersionUID = 1L;
+
+	/** Accumulator: {@code left} is a max-heap with the lower half of the values, {@code right} a min-heap with the upper half. */
+	public static class MedianAcc implements Serializable
+	{
+		private static final long serialVersionUID = 1L;
+
+		private final Class<?> repr;
+		private PriorityQueue<ScalarValue<?, ?, ?, ?>> left = new PriorityQueue<>(reverseOrder());
+		private PriorityQueue<ScalarValue<?, ?, ?, ?>> right = new PriorityQueue<>();
+
+		public MedianAcc(Class<?> repr)
+		{
+			this.repr = repr;
+		}
+
+		/** Adds a value so that no element of left exceeds an element of right and left holds as many elements as right, or one more. */
+		public void accumulate(ScalarValue<?, ?, ?, ?> value)
+		{
+			if (left.isEmpty() || value.compareTo(left.peek()) <= 0)
+			{
+				left.add(value);
+				if (left.size() > right.size() + 1)
+					right.add(left.poll());
+			}
+			else
+			{
+				right.add(value);
+				if (right.size() > left.size())
+					left.add(right.poll());
+			}
+		}
+
+		/** Merges the values held by {@code other} into this accumulator and returns this accumulator. */
+		public MedianAcc merge(MedianAcc other)
+		{
+			// Bulk addAll plus size rebalancing may leave values in left above values in right, giving a wrong median
+			// (e.g. {1,2|3,4} merged with {10,20|30,40} yielded 3 instead of 4); accumulate() keeps the halves ordered.
+			other.left.forEach(this::accumulate);
+			other.right.forEach(this::accumulate);
+
+			return this;
+		}
+
+		/** Returns the lower median, or empty if no value was accumulated; accumulate/merge already keep the heaps balanced. */
+		public Optional<ScalarValue<?, ?, ?, ?>> finish()
+		{
+			ScalarValue<?, ?, ?, ?> l = left.peek();
+			ScalarValue<?, ?, ?, ?> r = right.peek();
+
+			if (l != null && r != null)
+				return Optional.of(l.compareTo(r) <= 0 ? l : r);
+			else
+				return Optional.ofNullable(l);
+		}
+
+		/** Returns the class passed to the constructor; this accumulator only stores it. */
+		public Class<?> getRepr()
+		{
+			return repr;
+		}
+	}
+
+	/** Creates a collector with no characteristics whose accumulators are created with {@code repr}. */
+	public MedianCollector(Class<?> repr)
+	{
+		super(() -> new MedianAcc(repr), MedianAcc::accumulate, MedianAcc::merge, MedianAcc::finish, EnumSet.noneOf(Characteristics.class));
+	}
+
+	/** Static factory equivalent to {@code new MedianCollector(repr)}, with the accumulator type hidden from callers. */
+	public static SerCollector<ScalarValue<?, ?, ?, ?>, ?, Optional<ScalarValue<?, ?, ?, ?>>> medianCollector(Class<?> repr)
+	{
+		return new MedianCollector(repr);
+	}
+}