Skip to content

Commit

Permalink
No issue: Disable tests on windows due to https://bugs.openjdk.org/br…
Browse files Browse the repository at this point in the history
  • Loading branch information
reckart committed Aug 3, 2024
1 parent 105c7c7 commit bf32211
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 91 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/maven.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
os: [ubuntu-latest, windows-latest]
jdk: [17]

runs-on: ${{ matrix.os }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package org.dkpro.core.api.embeddings.binary;

import static java.nio.channels.FileChannel.MapMode.READ_ONLY;

import java.io.DataInput;
import java.io.DataOutputStream;
import java.io.File;
Expand All @@ -23,7 +25,6 @@
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Locale;
Expand All @@ -38,7 +39,7 @@
* @see BinaryWordVectorUtils
*/
public class BinaryVectorizer
implements Vectorizer
implements Vectorizer
{
private static final Logger LOG = LoggerFactory.getLogger(BinaryVectorizer.class);
private final String[] words;
Expand All @@ -51,7 +52,7 @@ public class BinaryVectorizer

private BinaryVectorizer(Header aHeader, RandomAccessFile aFile, String[] aWords,
long vectorStartOffset, float[] aUnk)
throws IOException
throws IOException
{
file = aFile;
header = aHeader;
Expand All @@ -71,20 +72,21 @@ private BinaryVectorizer(Header aHeader, RandomAccessFile aFile, String[] aWords
}

parts = new FloatBuffer[neededPartitions];
FileChannel channel = aFile.getChannel();
var channel = aFile.getChannel();
for (int i = 0; i < neededPartitions; i++) {
long start = vectorStartOffset + ((long) i * maxPartitionSizeBytes);
long length = maxPartitionSizeBytes;
if (i == neededPartitions - 1) {
length = (aWords.length % maxVectorsPerPartition) * header.getVectorLength()
* Float.BYTES;
}
parts[i] = channel.map(FileChannel.MapMode.READ_ONLY, start, length).asFloatBuffer();
parts[i] = channel.map(READ_ONLY, start, length).asFloatBuffer();
}
}

@Override
public void close() throws IOException {
public void close() throws IOException
{
if (file != null) {
file.close();
}
Expand All @@ -93,12 +95,13 @@ public void close() throws IOException {
/**
* Load a binary embeddings file and return a new {@link BinaryVectorizer} object.
*
* @param f a {@link File}
* @param f
* a {@link File}
* @return a new {@link BinaryVectorizer}
* @throws IOException if an I/O error occurs
* @throws IOException
* if an I/O error occurs
*/
public static BinaryVectorizer load(File f)
throws IOException
public static BinaryVectorizer load(File f) throws IOException
{
var file = new RandomAccessFile(f, "rw");

Expand All @@ -115,7 +118,8 @@ public static BinaryVectorizer load(File f)
// Load UNK vector
byte[] buffer = new byte[header.getVectorLength() * Float.BYTES];
file.readFully(buffer);
ByteBuffer byteBuffer = ByteBuffer.wrap(buffer);

var byteBuffer = ByteBuffer.wrap(buffer);
float[] unk = new float[header.getVectorLength()];
for (int i = 0; i < unk.length; i++) {
unk[i] = byteBuffer.getFloat(i * Float.BYTES);
Expand All @@ -126,8 +130,8 @@ public static BinaryVectorizer load(File f)
return new BinaryVectorizer(header, file, words, offset, unk);
}

@Override public float[] vectorize(String aWord)
throws IOException
@Override
public float[] vectorize(String aWord) throws IOException
{
String word = aWord;
if (header.isCaseless()) {
Expand Down Expand Up @@ -156,7 +160,8 @@ public static BinaryVectorizer load(File f)
return vector;
}

@Override public boolean contains(String aWord)
@Override
public boolean contains(String aWord)
{
String word = aWord;
if (header.isCaseless()) {
Expand All @@ -166,22 +171,26 @@ public static BinaryVectorizer load(File f)
return Arrays.binarySearch(words, word) >= 0;
}

@Override public float[] unknownVector()
@Override
public float[] unknownVector()
{
return unknownVector;
}

@Override public int dimensions()
@Override
public int dimensions()
{
return header.getVectorLength();
}

@Override public int size()
@Override
public int size()
{
return header.getWordCount();
}

@Override public boolean isCaseless()
@Override
public boolean isCaseless()
{
return header.isCaseless();
}
Expand All @@ -195,8 +204,7 @@ static class Header
private boolean caseless;
private String locale;

public static Header read(DataInput aInput)
throws IOException
public static Header read(DataInput aInput) throws IOException
{
byte[] magicBytes = new byte[MAGIC.length()];
aInput.readFully(magicBytes);
Expand Down Expand Up @@ -273,8 +281,7 @@ public void setVectorLength(int vectorLength)
this.vectorLength = vectorLength;
}

public void write(OutputStream aOutput)
throws IOException
public void write(OutputStream aOutput) throws IOException
{
DataOutputStream out = new DataOutputStream(aOutput);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Locale;
import java.util.Map;

Expand All @@ -48,13 +47,16 @@ public class BinaryWordVectorUtils
* Write a map of token embeddings into binary format. Uses the default locale {@link Locale#US}
* and assumes case-sensitivity iff there is any token containing an uppercase letter.
*
* @param vectors a {@code Map<String, float[]>} holding all tokens with embeddings
* @param binaryTarget the target file {@link File}
* @throws IOException if an I/O error occurs
* @param vectors
* a {@code Map<String, float[]>} holding all tokens with embeddings
* @param binaryTarget
* the target file {@link File}
* @throws IOException
* if an I/O error occurs
* @see #convertWordVectorsToBinary(Map, boolean, Locale, File)
*/
public static void convertWordVectorsToBinary(Map<String, float[]> vectors, File binaryTarget)
throws IOException
throws IOException
{
boolean caseless = vectors.keySet().stream()
.allMatch(token -> token.equals(token.toLowerCase()));
Expand All @@ -64,15 +66,20 @@ public static void convertWordVectorsToBinary(Map<String, float[]> vectors, File
/**
* Write a map of token embeddings into binary format.
*
* @param vectors a {@code Map<String, float[]>} holding all tokens with embeddings
* @param aCaseless if true, tokens are expected to be caseless
* @param aLocale the {@link Locale}
* @param binaryTarget the target file {@link File}
* @throws IOException if an I/O error occurs
* @param vectors
* a {@code Map<String, float[]>} holding all tokens with embeddings
* @param aCaseless
* if true, tokens are expected to be caseless
* @param aLocale
* the {@link Locale}
* @param binaryTarget
* the target file {@link File}
* @throws IOException
* if an I/O error occurs
*/
public static void convertWordVectorsToBinary(Map<String, float[]> vectors, boolean aCaseless,
Locale aLocale, File binaryTarget)
throws IOException
throws IOException
{
if (vectors.isEmpty()) {
throw new IllegalArgumentException("Word embeddings map must not be empty.");
Expand All @@ -82,47 +89,44 @@ public static void convertWordVectorsToBinary(Map<String, float[]> vectors, bool
assert vectors.values().stream().allMatch(v -> v.length == vectorLength);

Header header = prepareHeader(aCaseless, aLocale, vectors.size(), vectorLength);
DataOutputStream output = new DataOutputStream(
new BufferedOutputStream(new FileOutputStream(binaryTarget)));
header.write(output);
try (var output = new DataOutputStream(
new BufferedOutputStream(new FileOutputStream(binaryTarget)))) {
header.write(output);

LOG.info("Sorting data...");
String[] words = vectors.keySet().stream()
.sorted()
.toArray(String[]::new);
LOG.info("Sorting data...");
String[] words = vectors.keySet().stream().sorted().toArray(String[]::new);

LOG.info("Writing strings...");
for (String word : words) {
output.writeUTF(word);
}
LOG.info("Writing strings...");
for (String word : words) {
output.writeUTF(word);
}

LOG.info("Writing UNK vector...");
{
float[] vector = VectorizerUtils.randomVector(header.getVectorLength());
writeVector(output, vector);
}
LOG.info("Writing UNK vector...");
{
float[] vector = VectorizerUtils.randomVector(header.getVectorLength());
writeVector(output, vector);
}

LOG.info("Writing vectors...");
for (String word : words) {
float[] vector = vectors.get(word);
writeVector(output, vector);
LOG.info("Writing vectors...");
for (String word : words) {
float[] vector = vectors.get(word);
writeVector(output, vector);
}
}
output.close();
}

private static void writeVector(DataOutputStream output, float[] vector)
throws IOException
private static void writeVector(DataOutputStream output, float[] vector) throws IOException
{
ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES);
FloatBuffer floatBuffer = buffer.asFloatBuffer();
var buffer = ByteBuffer.allocate(vector.length * Float.BYTES);
var floatBuffer = buffer.asFloatBuffer();
floatBuffer.put(vector);
output.write(buffer.array());
}

private static Header prepareHeader(boolean aCaseless,
Locale aLocale, int wordCount, int vectorLength)
private static Header prepareHeader(boolean aCaseless, Locale aLocale, int wordCount,
int vectorLength)
{
Header header = new Header();
var header = new Header();
header.setVersion(1);
header.setWordCount(wordCount);
header.setVectorLength(vectorLength);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@
import org.dkpro.core.api.embeddings.VectorizerUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnOs;
import org.junit.jupiter.api.condition.OS;
import org.junit.jupiter.api.io.TempDir;

@DisabledOnOs(value = OS.WINDOWS, //
disabledReason = "mmapped buffers cannot be unmapped explicitly, so we cannot delete the temp dir on Windows")
public class BinaryWordVectorUtilsTest
{
private @TempDir File tempDir;

// TODO: test for very large data (>2GB should be chunked)
private Map<String, float[]> vectors;

Expand All @@ -50,30 +54,28 @@ public void setUp()
}

@Test
public void testConvertWordVectorsToBinary()
throws Exception
public void testConvertWordVectorsToBinary() throws Exception
{
var binaryTarget = writeBinaryFile(vectors);

try (var vec = BinaryVectorizer.load(binaryTarget)) {
assertThat(vec.contains("t1")).isTrue();
assertThat(vec.contains("t2")).isTrue();
assertThat(vec.dimensions()).isEqualTo(3);
assertThat(vec.size()).isEqualTo(2);
assertThat(vec.isCaseless()).isTrue();

for (var word : vectors.keySet()) {
var orig = vectors.get(word);
var conv = vec.vectorize(word);

assertThat(conv).containsExactly(orig);
}
}
}

@Test
public void testConvertWordVectorsToBinaryCaseSensitive()
throws Exception
public void testConvertWordVectorsToBinaryCaseSensitive() throws Exception
{
vectors.put("T1", new float[] { 0.1f, 0.2f, 0.3f });
var binaryTarget = writeBinaryFile(vectors);
Expand All @@ -97,31 +99,30 @@ public void testConvertWordVectorsToBinaryCaseSensitive()
}

@Test
public void testRandomVector()
throws IOException
public void testRandomVector() throws IOException
{
var binaryTarget = writeBinaryFile(vectors);

try (var vec = BinaryVectorizer.load(binaryTarget)) {
var randVector = VectorizerUtils.randomVector(3);

var unk1 = vec.vectorize("unk1");
var unk2 = vec.vectorize("unk2");
assertTrue(Arrays.equals(randVector, unk1));
assertTrue(Arrays.equals(randVector, unk2));
assertTrue(
Arrays.equals(unk1, unk2), "Vectors or unknown words should always be the same.");
assertTrue(Arrays.equals(unk1, unk2),
"Vectors or unknown words should always be the same.");
}
}

/**
* Write a binary vectors file named {@code binaryTarget} into the test's temporary directory.
*
* @return the binary vectors {@link File}
* @throws IOException if an I/O error occurs
* @throws IOException
* if an I/O error occurs
*/
private File writeBinaryFile(Map<String, float[]> vectors)
throws IOException
private File writeBinaryFile(Map<String, float[]> vectors) throws IOException
{
var binaryTarget = new File(tempDir, "binaryTarget");
convertWordVectorsToBinary(vectors, binaryTarget);
Expand Down
Loading

0 comments on commit bf32211

Please sign in to comment.