diff --git a/kala-collection/src/main/java/kala/collection/Collection.java b/kala-collection/src/main/java/kala/collection/Collection.java index 9129c54c..324d70cb 100644 --- a/kala-collection/src/main/java/kala/collection/Collection.java +++ b/kala-collection/src/main/java/kala/collection/Collection.java @@ -16,6 +16,8 @@ import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.UnmodifiableView; +import java.io.*; +import java.util.Iterator; import java.util.function.*; public interface Collection<@Covariant E> extends CollectionLike, AnyCollection { @@ -177,4 +179,65 @@ static Collection narrow(Collection collection) { default @NotNull ImmutableCollection distinct() { return distinct(ImmutableSeq.factory()); } + + final class SerializationWrapper> implements Serializable { + @Serial + private static final long serialVersionUID = 0L; + + private final CollectionFactory factory; + private transient C value; + + public SerializationWrapper(CollectionFactory factory, C value) { + this.factory = factory; + this.value = value; + } + + @SuppressWarnings("unchecked") + private static > C readObjectImpl(ObjectInputStream input, CollectionFactory factory, int size) throws IOException, ClassNotFoundException { + if (size < 0) { + throw new IOException("Invalid size: " + size); + } + + if (size == 0) { + return factory.empty(); + } + + B builder = factory.newBuilder(size); + for (int i = 0; i < size; i++) { + factory.addToBuilder(builder, (E) input.readObject()); + } + return factory.build(builder); + } + + @Serial + private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException { + input.defaultReadObject(); + value = readObjectImpl(input, factory, input.readInt()); + } + + @Serial + private void writeObject(ObjectOutputStream output) throws IOException { + output.defaultWriteObject(); + int size = value.size(); + + output.writeInt(size); + if (size == 0) { + return; + } + + Iterator iterator = value.iterator(); + for (int i = 0; i < size; i++) { + if (!iterator.hasNext()) { + throw new IOException("No more elements"); + } + + output.writeObject(iterator.next()); + } + } + + @Serial + private Object readResolve() { + return value; + } + } } diff --git a/kala-collection/src/main/java/kala/collection/Seq.java b/kala-collection/src/main/java/kala/collection/Seq.java index 6afba156..d340a927 100644 --- a/kala-collection/src/main/java/kala/collection/Seq.java +++ b/kala-collection/src/main/java/kala/collection/Seq.java @@ -5,7 +5,6 @@ import kala.collection.base.Iterators; import kala.collection.base.OrderedTraversable; import kala.collection.factory.CollectionFactory; -import kala.collection.immutable.ImmutableCollection; import kala.collection.immutable.ImmutableSeq; import kala.collection.internal.convert.AsJavaConvert; import kala.collection.internal.convert.FromJavaConvert; @@ -419,4 +418,5 @@ static boolean equals(@NotNull Seq seq1, @NotNull AnySeq seq2) { default @NotNull ImmutableSeq distinct() { return view().distinct().toImmutableSeq(); } + } diff --git a/kala-collection/src/main/java/kala/collection/immutable/ImmutableHashSet.java b/kala-collection/src/main/java/kala/collection/immutable/ImmutableHashSet.java index 53406828..54054145 100644 --- a/kala-collection/src/main/java/kala/collection/immutable/ImmutableHashSet.java +++ b/kala-collection/src/main/java/kala/collection/immutable/ImmutableHashSet.java @@ -162,7 +162,7 @@ public static ImmutableHashSet narrow(ImmutableHashSet set) @Serial private Object writeReplace() { - return new SerializationWrapper<>(this); + return new SerializationWrapper<>(factory(), this); } private static final class Builder { @@ -193,7 +193,10 @@ ImmutableHashSet build() { } } - private static final class Factory implements CollectionFactory, ImmutableHashSet> { + private static final class Factory implements CollectionFactory, ImmutableHashSet>, Serializable { + + @Serial + private static final long serialVersionUID = 0L; @Override public Builder newBuilder() { @@ -214,43 +217,11 @@ public void addToBuilder(@NotNull Builder builder, E value) { public Builder mergeBuilder(@NotNull Builder builder1, @NotNull Builder builder2) { return builder1.merge(builder2); } - } - - private static final class SerializationWrapper implements Externalizable { - private ImmutableHashSet value; - - public SerializationWrapper() { - } - - SerializationWrapper(ImmutableHashSet value) { - this.value = value; - } - - @Override - public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(value.size()); - for (E e : value) { - out.writeObject(e); - } - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - assert value == null; - - HashSet set = new HashSet<>(); - - int len = in.readInt(); - for (int i = 0; i < len; i++) { - set.add((E) in.readObject()); - } - - value = set.isEmpty() ? ImmutableHashSet.empty() : new ImmutableHashSet<>(set); - } @Serial private Object readResolve() { - return value; + return factory(); } } + } diff --git a/kala-collection/src/main/java/kala/collection/immutable/ImmutableMaps.java b/kala-collection/src/main/java/kala/collection/immutable/ImmutableMaps.java index 5f0332bd..4db544a2 100644 --- a/kala-collection/src/main/java/kala/collection/immutable/ImmutableMaps.java +++ b/kala-collection/src/main/java/kala/collection/immutable/ImmutableMaps.java @@ -83,7 +83,9 @@ public int size() { } } - static final class Map2 extends MapN { + static final class Map2 extends MapN implements Serializable { + @Serial + private static final long serialVersionUID = 0L; private final K k0; private final V v0; @@ -111,7 +113,9 @@ public int size() { } } - static final class Map3 extends MapN { + static final class Map3 extends MapN implements Serializable { + @Serial + private static final long serialVersionUID = 0L; private final K k0; private final V v0; diff --git a/src/test/template/kala/collection/SetTestTemplate.java b/src/test/template/kala/collection/SetTestTemplate.java index 3e723e4e..821bbfb1 100644 --- a/src/test/template/kala/collection/SetTestTemplate.java +++ b/src/test/template/kala/collection/SetTestTemplate.java @@ -1,7 +1,14 @@ package kala.collection; +import kala.SerializationUtils; import kala.collection.factory.CollectionFactory; +import org.junit.jupiter.api.Test; +import java.io.*; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@SuppressWarnings("unchecked") public interface SetTestTemplate extends SetLikeTestTemplate, CollectionTestTemplate { @Override @@ -20,4 +27,29 @@ public interface SetTestTemplate extends SetLikeTestTemplate, CollectionTestTemp default void factoryTest() { } + + @Test + default void serializationTest() throws IOException, ClassNotFoundException { + try { + for (Integer[] data : data1()) { + Collection c = factory().from(data); + ByteArrayOutputStream out = new ByteArrayOutputStream(4 * 128); + new ObjectOutputStream(out).writeObject(c); + byte[] buffer = out.toByteArray(); + ByteArrayInputStream in = new ByteArrayInputStream(buffer); + var obj = (Set) new ObjectInputStream(in).readObject(); + + assertEquals(c, obj); + } + } catch (NotSerializableException ignored) { + } + + assertEquals(of(), SerializationUtils.writeAndRead(of())); + assertEquals(of(0), SerializationUtils.writeAndRead(of(0))); + assertEquals(of(0, 1, 2), SerializationUtils.writeAndRead(of(0, 1, 2))); + + for (String[] data : data1s()) { + assertEquals(from(data), SerializationUtils.writeAndRead(from(data))); + } + } }