Commit

Fixes to CSV and SDMX Environments
vpinna80 committed Oct 25, 2023
1 parent 9be3479 commit 1ca985e
Showing 7 changed files with 29 additions and 48 deletions.
@@ -64,6 +64,7 @@
import it.bancaditalia.oss.vtl.config.ConfigurationManager;
import it.bancaditalia.oss.vtl.config.ConfigurationManagerFactory;
import it.bancaditalia.oss.vtl.environment.Environment;
import it.bancaditalia.oss.vtl.exceptions.VTLMissingComponentsException;
import it.bancaditalia.oss.vtl.exceptions.VTLNestedException;
import it.bancaditalia.oss.vtl.impl.environment.util.CSVParseUtils;
import it.bancaditalia.oss.vtl.impl.environment.util.ProgressWindow;
@@ -188,10 +189,13 @@ protected Stream<DataPoint> streamFileName(String fileName, String alias)
}
else
{
DataSetMetadata structure = maybeStructure;

metadata = Arrays.stream(reader.readLine().split(","))
.map(maybeStructure::getComponent)
.map(Optional::get)
.map(toEntryWithValue(maybeStructure::getComponent))
.map(e -> e.getValue().orElseThrow(() -> new VTLMissingComponentsException(e.getKey(), structure)))
.collect(toList());

masks = metadata.stream()
.map(toEntryWithValue(DataStructureComponent::getDomain))
.map(keepingKey(ValueDomainSubset::getName))
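The hunk above replaces the bare Optional::get on each CSV header column with toEntryWithValue plus orElseThrow, so an unknown column now fails with a VTLMissingComponentsException naming the column and the target structure instead of surfacing as a generic NoSuchElementException. A minimal sketch of that lookup pattern, using simplified stand-in types rather than the project's actual classes:

// Sketch of the header-validation pattern; MissingComponentException and
// resolveHeader are hypothetical stand-ins, not the VTL engine's real API.
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

class MissingComponentException extends RuntimeException
{
	public MissingComponentException(String column, Set<String> structure)
	{
		super("Column '" + column + "' is not part of the structure " + structure);
	}
}

class HeaderCheck
{
	static List<String> resolveHeader(String headerLine, Set<String> structure)
	{
		return Arrays.stream(headerLine.split(","))
			// fail fast with a descriptive message instead of NoSuchElementException
			.map(name -> Optional.of(name)
				.filter(structure::contains)
				.orElseThrow(() -> new MissingComponentException(name, structure)))
			.collect(Collectors.toList());
	}

	public static void main(String[] args)
	{
		Set<String> structure = Set.of("TIME_PERIOD", "OBS_VALUE");
		System.out.println(resolveHeader("TIME_PERIOD,OBS_VALUE", structure));
	}
}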
@@ -329,7 +329,7 @@ public synchronized DataPoint next()
builder.add(c, c.cast(StringValue.of(a.getCode())));
}

DataStructureComponent<Measure, ?, ?> measure = structure.getComponents(Measure.class).iterator().next();
DataStructureComponent<Measure, ?, ?> measure = structure.getMeasures().iterator().next();
builder.add(measure, DoubleValue.of(parseDouble(obs.getMeasureValue(measure.getName()))));
TemporalAccessor holder = parser.getValue().queryFrom(parser.getKey().parse(obs.getDimensionValue()));
ScalarValue<?, ?, ?, ?> value;
4 changes: 4 additions & 0 deletions vtl-envs/vtl-spark/pom.xml
@@ -96,6 +96,10 @@
<groupId>org.codehaus.mojo</groupId>
<artifactId>flatten-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
@@ -22,7 +22,6 @@
import static it.bancaditalia.oss.vtl.impl.environment.spark.SparkEnvironment.LineageSparkUDT;
import static it.bancaditalia.oss.vtl.impl.environment.spark.SparkUtils.createStructFromComponents;
import static it.bancaditalia.oss.vtl.impl.environment.spark.SparkUtils.getScalarFor;
import static it.bancaditalia.oss.vtl.impl.environment.spark.SparkUtils.sorter;
import static java.util.stream.Collectors.joining;

import java.io.Serializable;
@@ -70,7 +69,7 @@ public DataPointEncoder(Set<? extends DataStructureComponent<?, ?, ?>> dataStruc
{
structure = dataStructure instanceof DataSetMetadata ? (DataSetMetadata) dataStructure : new DataStructureBuilder(dataStructure).build();
components = structure.toArray(new DataStructureComponent<?, ?, ?>[structure.size()]);
Arrays.sort(components, SparkUtils::sorter);
Arrays.sort(components, DataStructureComponent::byNameAndRole);
List<StructField> fields = new ArrayList<>(createStructFromComponents(components));
StructType schemaNoLineage = new StructType(fields.toArray(new StructField[components.length]));
rowEncoderNoLineage = Encoders.row(schemaNoLineage);
@@ -249,7 +248,7 @@ public DataPoint combine(DataPoint other, SerBiFunction<DataPoint, DataPoint, Li
while (j < dpo.comps.length && dpo.comps[j] == null)
j++;

int compare = i < comps.length ? j < dpo.comps.length ? sorter(comps[i], dpo.comps[j]) : -1 : 1;
int compare = i < comps.length ? j < dpo.comps.length ? DataStructureComponent.byNameAndRole(comps[i], dpo.comps[j]) : -1 : 1;
if (compare < 0)
{
comps2[k] = comps[i];
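In DataPointEncoder the local SparkUtils::sorter calls are swapped for DataStructureComponent::byNameAndRole, both where the component array is sorted and where two components are compared directly during a merge. This works because any static method taking two arguments of the same type and returning an int can be passed where a Comparator is expected; a small self-contained illustration (the class and method names below are made up):

// Illustration only: a static (T, T) -> int method used directly as a Comparator,
// mirroring Arrays.sort(components, DataStructureComponent::byNameAndRole).
import java.util.Arrays;

class MethodRefComparator
{
	static int byLengthThenText(String a, String b)
	{
		int byLength = Integer.compare(a.length(), b.length());
		return byLength != 0 ? byLength : a.compareTo(b);
	}

	public static void main(String[] args)
	{
		String[] names = { "obs_value", "id_1", "flag" };
		Arrays.sort(names, MethodRefComparator::byLengthThenText);
		System.out.println(Arrays.toString(names)); // [flag, id_1, obs_value]
	}
}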
@@ -28,9 +28,7 @@
import static it.bancaditalia.oss.vtl.impl.environment.spark.SparkUtils.getMetadataFor;
import static it.bancaditalia.oss.vtl.impl.environment.spark.SparkUtils.getNamesFromComponents;
import static it.bancaditalia.oss.vtl.impl.environment.spark.SparkUtils.getScalarFor;
import static it.bancaditalia.oss.vtl.impl.environment.spark.SparkUtils.sorter;
import static it.bancaditalia.oss.vtl.impl.types.dataset.DataPointBuilder.toDataPoint;
import static it.bancaditalia.oss.vtl.model.data.DataStructureComponent.byName;
import static it.bancaditalia.oss.vtl.model.transform.analytic.LimitCriterion.LimitDirection.PRECEDING;
import static it.bancaditalia.oss.vtl.model.transform.analytic.SortCriterion.SortingMethod.ASC;
import static it.bancaditalia.oss.vtl.model.transform.analytic.WindowCriterion.LimitType.RANGE;
@@ -107,7 +105,6 @@
import it.bancaditalia.oss.vtl.impl.types.lineage.LineageExternal;
import it.bancaditalia.oss.vtl.impl.types.lineage.LineageNode;
import it.bancaditalia.oss.vtl.model.data.ComponentRole.Identifier;
import it.bancaditalia.oss.vtl.model.data.ComponentRole.Measure;
import it.bancaditalia.oss.vtl.model.data.ComponentRole.NonIdentifier;
import it.bancaditalia.oss.vtl.model.data.DataPoint;
import it.bancaditalia.oss.vtl.model.data.DataSet;
@@ -192,7 +189,7 @@ public DataSet membership(String alias)
final DataSetMetadata membershipStructure = getMetadata().membership(alias);
LOGGER.debug("Creating dataset by membership on {} from {} to {}", alias, getMetadata(), membershipStructure);

DataStructureComponent<? extends NonIdentifier, ?, ?> membershipMeasure = membershipStructure.getComponents(Measure.class).iterator().next();
DataStructureComponent<? extends NonIdentifier, ?, ?> membershipMeasure = membershipStructure.getMeasures().iterator().next();

Dataset<Row> newDF = dataFrame;
if (!getMetadata().contains(membershipMeasure))
@@ -256,9 +253,9 @@ public DataSet subspace(Map<? extends DataStructureComponent<? extends Identifie
public DataSet mapKeepingKeys(DataSetMetadata metadata, SerFunction<? super DataPoint, ? extends Lineage> lineageOperator,
SerFunction<? super DataPoint, ? extends Map<? extends DataStructureComponent<?, ?, ?>, ? extends ScalarValue<?, ?, ?, ?>>> operator)
{
final Set<DataStructureComponent<Identifier, ?, ?>> originalIDs = getMetadata().getComponents(Identifier.class);
if (!metadata.getComponents(Identifier.class).containsAll(originalIDs))
throw new VTLInvariantIdentifiersException("map", originalIDs, metadata.getComponents(Identifier.class));
final Set<DataStructureComponent<Identifier, ?, ?>> originalIDs = getMetadata().getIDs();
if (!metadata.getIDs().containsAll(originalIDs))
throw new VTLInvariantIdentifiersException("map", originalIDs, metadata.getIDs());

LOGGER.trace("Creating dataset by mapping from {} to {}", getMetadata(), metadata);

@@ -270,7 +267,7 @@ public DataSet mapKeepingKeys(DataSetMetadata metadata, SerFunction<? super Data

// compute values
Object[] resultArray = map.entrySet().stream()
.sorted((a, b) -> sorter(a.getKey(), b.getKey()))
.sorted((a, b) -> DataStructureComponent.byNameAndRole(a.getKey(), b.getKey()))
.map(Entry::getValue)
.map(ScalarValue::get)
.collect(toArray(new Object[map.size() + 1]));
@@ -315,7 +312,7 @@ public DataSet flatmapKeepingKeys(DataSetMetadata metadata, SerFunction<? super
{
DataPointEncoder resultEncoder = new DataPointEncoder(metadata);
DataStructureComponent<?, ?, ?>[] comps = resultEncoder.components;
Set<DataStructureComponent<Identifier, ?, ?>> ids = getMetadata().getComponents(Identifier.class);
Set<DataStructureComponent<Identifier, ?, ?>> ids = getMetadata().getIDs();
StructType schema = resultEncoder.schema;

Dataset<Row> flattenedDf = dataFrame.flatMap((FlatMapFunction<Row, Row>) row -> {
@@ -350,7 +347,7 @@ public DataSet filteredMappedJoin(DataSetMetadata metadata, DataSet other, SerBi
{
SparkDataSet sparkOther = other instanceof SparkDataSet ? ((SparkDataSet) other) : new SparkDataSet(session, other.getMetadata(), other);

List<String> commonIDs = getComponents(Identifier.class).stream()
List<String> commonIDs = getMetadata().getIDs().stream()
.filter(other.getMetadata()::contains)
.map(DataStructureComponent::getName)
.collect(toList());
@@ -440,7 +437,7 @@ public <TT> DataSet analytic(SerFunction<DataPoint, Lineage> lineageOp,
// Sort by dest component
@SuppressWarnings("unchecked")
Entry<DataStructureComponent<?, ?, ?>, DataStructureComponent<?, ?, ?>>[] compArray = (Entry<DataStructureComponent<?, ?, ?>, DataStructureComponent<?, ?, ?>>[]) components.entrySet().toArray(new Entry<?, ?>[components.size()]);
Arrays.sort(compArray, (e1, e2) -> sorter(e1.getValue(), e2.getValue()));
Arrays.sort(compArray, (e1, e2) -> DataStructureComponent.byNameAndRole(e1.getValue(), e2.getValue()));

// Create the udafs to generate each dest component
Map<String, Column> destComponents = new HashMap<>();
@@ -569,11 +566,11 @@ public <A, T, TT> Stream<T> streamByKeys(Set<DataStructureComponent<Identifier,
{
@SuppressWarnings("unchecked")
DataStructureComponent<Identifier, ?, ?>[] sortedKeys = (DataStructureComponent<Identifier, ?, ?>[]) keys.stream()
.sorted(byName())
.sorted(DataStructureComponent::byName)
.collect(toArray(new DataStructureComponent<?, ?, ?>[keys.size()]));

Column[] groupingCols = keys.stream()
.sorted(byName())
.sorted(DataStructureComponent::byName)
.map(DataStructureComponent::getName)
.map(functions::col)
.collect(toArray(new Column[keys.size()]));
@@ -593,7 +590,7 @@ public <A, T, TT> Stream<T> streamByKeys(Set<DataStructureComponent<Identifier,

// case: supports decoding into a List<DataPoint> for fill_time_series
List<DataStructureComponent<?, ?, ?>> resultComponents = getMetadata().stream()
.sorted(SparkUtils::sorter)
.sorted(DataStructureComponent::byNameAndRole)
.collect(toList());

// Use kryo encoder hoping that the class has been registered beforehand
@@ -641,7 +638,7 @@ public DataSet union(SerFunction<DataPoint, Lineage> lineageOp, List<DataSet> ot
.get();

// remove duplicates and add lineage
Column[] ids = getColumnsFromComponents(getMetadata().getComponents(Identifier.class)).toArray(new Column[0]);
Column[] ids = getColumnsFromComponents(getMetadata().getIDs()).toArray(new Column[0]);
Column[] cols = getColumnsFromComponents(getMetadata()).toArray(new Column[getMetadata().size()]);
Column lineage = new Column(Literal.create(LineageSparkUDT.serialize(LineageExternal.of("Union")), LineageSparkUDT));
result = result.withColumn("__index", first("__index").over(partitionBy(ids).orderBy(result.col("__index"))))
@@ -656,7 +653,7 @@ public DataSet union(SerFunction<DataPoint, Lineage> lineageOp, List<DataSet> ot
public DataSet setDiff(DataSet other)
{
SparkDataSet sparkOther = other instanceof SparkDataSet ? ((SparkDataSet) other) : new SparkDataSet(session, other.getMetadata(), other);
List<String> ids = getMetadata().getComponents(Identifier.class).stream().map(DataStructureComponent::getName).collect(toList());
List<String> ids = getMetadata().getIDs().stream().map(DataStructureComponent::getName).collect(toList());
Dataset<Row> result = dataFrame.join(sparkOther.dataFrame, asScala((Iterable<String>) ids).toSeq(), "leftanti");

Column[] cols = getColumnsFromComponents(getMetadata()).toArray(new Column[getMetadata().size() + 1]);
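Throughout SparkDataSet the calls getComponents(Identifier.class) and getComponents(Measure.class) are replaced by the shorter getIDs() and getMeasures() accessors. Assuming these are convenience methods layered over the role-based lookup (the interfaces below are simplified stand-ins, not the actual VTL model API), the pattern looks roughly like this:

// Simplified stand-in for the structure API; shows how role-specific shorthands
// can delegate to a generic role lookup. Not the project's real types.
import java.util.Set;
import java.util.stream.Collectors;

interface Component
{
	enum Role { IDENTIFIER, MEASURE, ATTRIBUTE }

	String name();
	Role role();
}

interface Structure
{
	Set<Component> getComponents();

	default Set<Component> getComponents(Component.Role role)
	{
		return getComponents().stream()
			.filter(c -> c.role() == role)
			.collect(Collectors.toSet());
	}

	// the shorthands the commit switches the Spark call sites to
	default Set<Component> getIDs()
	{
		return getComponents(Component.Role.IDENTIFIER);
	}

	default Set<Component> getMeasures()
	{
		return getComponents(Component.Role.MEASURE);
	}
}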
@@ -77,7 +77,6 @@
import it.bancaditalia.oss.vtl.impl.types.lineage.LineageImpl;
import it.bancaditalia.oss.vtl.impl.types.lineage.LineageNode;
import it.bancaditalia.oss.vtl.impl.types.lineage.LineageSet;
import it.bancaditalia.oss.vtl.model.data.ComponentRole.Identifier;
import it.bancaditalia.oss.vtl.model.data.DataSet;
import it.bancaditalia.oss.vtl.model.data.DataSetMetadata;
import it.bancaditalia.oss.vtl.model.data.DataStructureComponent;
@@ -311,7 +310,7 @@ else if (field.dataType() instanceof IntegerType)
Column[] converters = Arrays.stream(normalizedNames, 0, normalizedNames.length)
.map(structure::getComponent)
.map(Optional::get)
.sorted(SparkUtils::sorter)
.sorted(DataStructureComponent::byNameAndRole)
.map(c -> udf(repr -> mapValue(c, repr.toString(), masks.get(c)).get(), types.get(c))
.apply(sourceDataFrame.col(newToOldNames.get(c.getName())))
.as(c.getName(), getMetadataFor(c)))
@@ -323,7 +322,7 @@ else if (field.dataType() instanceof IntegerType)
converters[converters.length - 1] = lit(serializedLineage).alias("$lineage$");

Dataset<Row> converted = sourceDataFrame.select(converters);
Column[] ids = getColumnsFromComponents(structure.getComponents(Identifier.class)).toArray(new Column[0]);
Column[] ids = getColumnsFromComponents(structure.getIDs()).toArray(new Column[0]);
return new SparkDataSet(session, structure, converted.repartition(ids));
}
}
@@ -49,8 +49,6 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

import org.apache.spark.sql.Column;
@@ -220,31 +218,11 @@ public static List<Column> getColumnsFromComponents(Collection<? extends DataStr
public static <F> List<F> structHelper(Stream<? extends DataStructureComponent<?, ?, ?>> stream, SerFunction<? super DataStructureComponent<?, ?, ?>, F> mapper)
{
return stream
.sorted(SparkUtils::sorter)
.sorted(DataStructureComponent::byNameAndRole)
.map(mapper)
.collect(toList());
}

public static int sorter(DataStructureComponent<?, ?, ?> c1, DataStructureComponent<?, ?, ?> c2)
{
if (c1.is(Attribute.class) && !c2.is(Attribute.class))
return 1;
else if (c1.is(Identifier.class) && !c2.is(Identifier.class))
return -1;
else if (c1.is(Measure.class) && c2.is(Identifier.class))
return 1;
else if (c1.is(Measure.class) && c2.is(Attribute.class))
return -1;

String n1 = c1.getName(), n2 = c2.getName();
Pattern pattern = Pattern.compile("^(.+?)(\\d+)$");
Matcher m1 = pattern.matcher(n1), m2 = pattern.matcher(n2);
if (m1.find() && m2.find() && m1.group(1).equals(m2.group(1)))
return Integer.compare(Integer.parseInt(m1.group(2)), Integer.parseInt(m2.group(2)));
else
return n1.compareTo(n2);
}

private SparkUtils()
{
}
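The last hunk deletes the local sorter comparator from SparkUtils; per the other hunks its role is taken over by DataStructureComponent::byNameAndRole in the model API. The removed logic ordered components by role (identifiers first, then measures, then attributes) and, for names sharing a textual prefix, by their numeric suffix. A self-contained sketch of that ordering, with hypothetical stand-in types instead of the VTL component interfaces:

// Sketch of the removed role-then-name ordering; Role and Comp are hypothetical
// stand-ins for the VTL component types, not the project's actual classes.
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

enum Role { IDENTIFIER, MEASURE, ATTRIBUTE }

record Comp(String name, Role role) {}

class ComponentOrder
{
	private static final Pattern SUFFIX = Pattern.compile("^(.+?)(\\d+)$");

	static int byRoleThenName(Comp c1, Comp c2)
	{
		// identifiers sort before measures, measures before attributes
		int byRole = Integer.compare(c1.role().ordinal(), c2.role().ordinal());
		if (byRole != 0)
			return byRole;

		// names sharing a prefix compare by numeric suffix, so id_2 sorts before id_10
		Matcher m1 = SUFFIX.matcher(c1.name()), m2 = SUFFIX.matcher(c2.name());
		if (m1.find() && m2.find() && m1.group(1).equals(m2.group(1)))
			return Integer.compare(Integer.parseInt(m1.group(2)), Integer.parseInt(m2.group(2)));

		return c1.name().compareTo(c2.name());
	}

	public static void main(String[] args)
	{
		List<Comp> comps = new ArrayList<>(List.of(
			new Comp("obs_value", Role.MEASURE),
			new Comp("id_10", Role.IDENTIFIER),
			new Comp("id_2", Role.IDENTIFIER),
			new Comp("comment", Role.ATTRIBUTE)));
		comps.sort(ComponentOrder::byRoleThenName);
		System.out.println(comps);
	}
}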
