Skip to content

Commit

Permalink
improve small errors on shm communication
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 2, 2024
1 parent 6ad58b1 commit d8c0725
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ private static void buildFromTensorUByte(Tensor<UInt8> tensor, String memoryName
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -106,7 +106,7 @@ private static void buildFromTensorInt(Tensor<Integer> tensor, String memoryName
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -119,7 +119,7 @@ private static void buildFromTensorFloat(Tensor<Float> tensor, String memoryName
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -132,7 +132,7 @@ private static void buildFromTensorDouble(Tensor<Double> tensor, String memoryNa
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -146,7 +146,7 @@ private static void buildFromTensorLong(Tensor<Long> tensor, String memoryName)


SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
ByteBuffer buff = shma.getDataBuffer();
ByteBuffer buff = shma.getDataBufferNoHeader();
tensor.writeTo(buff);
if (PlatformDetection.isWindows()) shma.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ private static Tensor<Integer> buildInt(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
IntBuffer intBuff = buff.asIntBuffer();
int[] intArray = new int[intBuff.capacity()];
intBuff.get(intArray);
Tensor<Integer> ndarray = Tensor.create(ogShape, intBuff);
return ndarray;
}
Expand All @@ -122,8 +120,6 @@ private static Tensor<Long> buildLong(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
LongBuffer longBuff = buff.asLongBuffer();
long[] longArray = new long[longBuff.capacity()];
longBuff.get(longArray);
Tensor<Long> ndarray = Tensor.create(ogShape, longBuff);
return ndarray;
}
Expand All @@ -139,8 +135,6 @@ private static Tensor<Float> buildFloat(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
FloatBuffer floatBuff = buff.asFloatBuffer();
float[] floatArray = new float[floatBuff.capacity()];
floatBuff.get(floatArray);
Tensor<Float> ndarray = Tensor.create(ogShape, floatBuff);
return ndarray;
}
Expand All @@ -156,8 +150,6 @@ private static Tensor<Double> buildDouble(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
double[] doubleArray = new double[doubleBuff.capacity()];
doubleBuff.get(doubleArray);
Tensor<Double> ndarray = Tensor.create(ogShape, doubleBuff);
return ndarray;
}
Expand Down

0 comments on commit d8c0725

Please sign in to comment.