Skip to content

Commit

Permalink
Create Collection.SerializationWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Glavo committed Jun 6, 2024
1 parent 953fa05 commit 12be657
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 39 deletions.
63 changes: 63 additions & 0 deletions kala-collection/src/main/java/kala/collection/Collection.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.UnmodifiableView;

import java.io.*;
import java.util.Iterator;
import java.util.function.*;

public interface Collection<@Covariant E> extends CollectionLike<E>, AnyCollection<E> {
Expand Down Expand Up @@ -177,4 +179,65 @@ static <E> Collection<E> narrow(Collection<? extends E> collection) {
default @NotNull ImmutableCollection<E> distinct() {
return distinct(ImmutableSeq.factory());
}

final class SerializationWrapper<E, C extends Collection<E>> implements Serializable {
@Serial
private static final long serialVersionUID = 0L;

private final CollectionFactory<E, ?, C> factory;
private transient C value;

public SerializationWrapper(CollectionFactory<E, ?, C> factory, C value) {
this.factory = factory;
this.value = value;
}

@SuppressWarnings("unchecked")
private static <E, B, C extends Collection<E>> C readObjectImpl(ObjectInputStream input, CollectionFactory<E, B, C> factory, int size) throws IOException, ClassNotFoundException {
if (size < 0) {
throw new IOException("Invalid size: " + size);
}

if (size == 0) {
return factory.empty();
}

B builder = factory.newBuilder(size);
for (int i = 0; i < size; i++) {
factory.addToBuilder(builder, (E) input.readObject());
}
return factory.build(builder);
}

@Serial
private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
input.defaultReadObject();
value = readObjectImpl(input, factory, input.readInt());
}

@Serial
private void writeObject(ObjectOutputStream output) throws IOException {
output.defaultWriteObject();
int size = value.size();

output.writeInt(size);
if (size == 0) {
return;
}

Iterator<E> iterator = value.iterator();
for (int i = 0; i < size; i++) {
if (!iterator.hasNext()) {
throw new IOException("No more elements");
}

output.writeObject(iterator.next());
}
}

@Serial
private Object readResolve() {
return value;
}
}
}
2 changes: 1 addition & 1 deletion kala-collection/src/main/java/kala/collection/Seq.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import kala.collection.base.Iterators;
import kala.collection.base.OrderedTraversable;
import kala.collection.factory.CollectionFactory;
import kala.collection.immutable.ImmutableCollection;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.internal.convert.AsJavaConvert;
import kala.collection.internal.convert.FromJavaConvert;
Expand Down Expand Up @@ -419,4 +418,5 @@ static boolean equals(@NotNull Seq<?> seq1, @NotNull AnySeq<?> seq2) {
default @NotNull ImmutableSeq<E> distinct() {
return view().distinct().toImmutableSeq();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public static <E> ImmutableHashSet<E> narrow(ImmutableHashSet<? extends E> set)

@Serial
private Object writeReplace() {
return new SerializationWrapper<>(this);
return new SerializationWrapper<>(factory(), this);
}

private static final class Builder<E> {
Expand Down Expand Up @@ -193,7 +193,10 @@ ImmutableHashSet<E> build() {
}
}

private static final class Factory<E> implements CollectionFactory<E, Builder<E>, ImmutableHashSet<E>> {
private static final class Factory<E> implements CollectionFactory<E, Builder<E>, ImmutableHashSet<E>>, Serializable {

@Serial
private static final long serialVersionUID = 0L;

@Override
public Builder<E> newBuilder() {
Expand All @@ -214,43 +217,11 @@ public void addToBuilder(@NotNull Builder<E> builder, E value) {
public Builder<E> mergeBuilder(@NotNull Builder<E> builder1, @NotNull Builder<E> builder2) {
return builder1.merge(builder2);
}
}

private static final class SerializationWrapper<E> implements Externalizable {
private ImmutableHashSet<E> value;

public SerializationWrapper() {
}

SerializationWrapper(ImmutableHashSet<E> value) {
this.value = value;
}

@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(value.size());
for (E e : value) {
out.writeObject(e);
}
}

@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
assert value == null;

HashSet<E> set = new HashSet<>();

int len = in.readInt();
for (int i = 0; i < len; i++) {
set.add((E) in.readObject());
}

value = set.isEmpty() ? ImmutableHashSet.empty() : new ImmutableHashSet<>(set);
}

@Serial
private Object readResolve() {
return value;
return factory();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ public int size() {
}
}

static final class Map2<K, V> extends MapN<K, V> {
static final class Map2<K, V> extends MapN<K, V> implements Serializable {
@Serial
private static final long serialVersionUID = 0L;

private final K k0;
private final V v0;
Expand Down Expand Up @@ -111,7 +113,9 @@ public int size() {
}
}

static final class Map3<K, V> extends MapN<K, V> {
static final class Map3<K, V> extends MapN<K, V> implements Serializable {
@Serial
private static final long serialVersionUID = 0L;

private final K k0;
private final V v0;
Expand Down
32 changes: 32 additions & 0 deletions src/test/template/kala/collection/SetTestTemplate.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
package kala.collection;

import kala.SerializationUtils;
import kala.collection.factory.CollectionFactory;
import org.junit.jupiter.api.Test;

import java.io.*;

import static org.junit.jupiter.api.Assertions.assertEquals;

@SuppressWarnings("unchecked")
public interface SetTestTemplate extends SetLikeTestTemplate, CollectionTestTemplate {

@Override
Expand All @@ -20,4 +27,29 @@ public interface SetTestTemplate extends SetLikeTestTemplate, CollectionTestTemp
default void factoryTest() {

}

@Test
default void serializationTest() throws IOException, ClassNotFoundException {
try {
for (Integer[] data : data1()) {
Collection<?> c = factory().from(data);
ByteArrayOutputStream out = new ByteArrayOutputStream(4 * 128);
new ObjectOutputStream(out).writeObject(c);
byte[] buffer = out.toByteArray();
ByteArrayInputStream in = new ByteArrayInputStream(buffer);
var obj = (Set<Integer>) new ObjectInputStream(in).readObject();

assertEquals(c, obj);
}
} catch (NotSerializableException ignored) {
}

assertEquals(of(), SerializationUtils.writeAndRead(of()));
assertEquals(of(0), SerializationUtils.writeAndRead(of(0)));
assertEquals(of(0, 1, 2), SerializationUtils.writeAndRead(of(0, 1, 2)));

for (String[] data : data1s()) {
assertEquals(from(data), SerializationUtils.writeAndRead(from(data)));
}
}
}

0 comments on commit 12be657

Please sign in to comment.