Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Commit

Permalink
Merge pull request #97 from scikit-hep/issue-92
Browse files Browse the repository at this point in the history
Fixes #92: all awkward arrays now have an nbytes parameter
  • Loading branch information
jpivarski authored Mar 9, 2019
2 parents 1681251 + b083bd8 commit bd803ca
Show file tree
Hide file tree
Showing 18 changed files with 130 additions and 5 deletions.
4 changes: 4 additions & 0 deletions awkward/array/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def __bool__(self):
def size(self):
return len(self)

@property
def nbytes(self):
    """Total byte count of this array, as computed by ``_getnbytes``.

    A fresh, empty "seen" set is handed to ``_getnbytes`` on every access,
    so each read of the property starts a new traversal.
    """
    visited = set()
    return self._getnbytes(visited)

def tolist(self):
import awkward.array.table
out = []
Expand Down
7 changes: 7 additions & 0 deletions awkward/array/chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ def _gettype(self, seen):

return tpe

def _getnbytes(self, seen):
    """Sum the byte counts of every chunk, visiting this array at most once.

    ``seen`` is a set of ``id()`` values of arrays already counted during
    the current traversal; if this array is already in it, 0 is returned.
    Numpy chunks contribute their ``nbytes``; any other chunk is asked for
    its own ``_getnbytes(seen)``.
    """
    if id(self) in seen:
        return 0
    seen.add(id(self))

    total = 0
    for chunk in self._chunks:
        if isinstance(chunk, self.numpy.ndarray):
            total += chunk.nbytes
        else:
            total += chunk._getnbytes(seen)
    return total

def __len__(self):
self.knowcounts()
return self.offsets[-1]
Expand Down
14 changes: 14 additions & 0 deletions awkward/array/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ def content(self, value):
self._content = self._util_toarray(value, self.DEFAULTTYPE)
self._isvalid = False

def _getnbytes(self, seen):
    """Return the bytes of the index buffer plus the bytes of the content.

    ``seen`` holds the ``id()`` of every array already counted in this
    traversal; an array found there contributes 0 so it is not counted
    twice.
    """
    if id(self) in seen:
        return 0
    seen.add(id(self))

    index_bytes = self._index.nbytes
    if isinstance(self._content, self.numpy.ndarray):
        content_bytes = self._content.nbytes
    else:
        # Non-numpy content reports its own size through the same protocol.
        content_bytes = self._content._getnbytes(seen)
    return index_bytes + content_bytes

def __len__(self):
return len(self._index)

Expand Down Expand Up @@ -376,6 +383,13 @@ def default(self, value):
def _gettype(self, seen):
return awkward.type._fromarray(self._content, seen)

def _getnbytes(self, seen):
    """Byte count of ``_index`` plus ``_content``; 0 if already visited.

    ``seen`` is the set of ``id()`` values of arrays counted so far and is
    updated in place with this array's id.
    """
    key = id(self)
    if key not in seen:
        seen.add(key)
        total = self._index.nbytes
        inner = self._content
        if isinstance(inner, self.numpy.ndarray):
            total += inner.nbytes
        else:
            total += inner._getnbytes(seen)
        return total
    return 0

def __len__(self):
return self._length

Expand Down
10 changes: 10 additions & 0 deletions awkward/array/jagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,16 @@ def index(self):
out = self.numpy.arange(len(self._content), dtype=self.INDEXTYPE)
return self.copy(content=(out - out[self._starts[self.parents]]))

def _getnbytes(self, seen):
    """Return the bytes used by the starts/stops buffers plus the content.

    ``seen`` is the set of ``id()`` values of arrays already counted; this
    array contributes 0 if it is already present, otherwise its id is added.

    When ``offsetsaliased(starts, stops)`` is true, only the ``base`` of
    ``_starts`` is counted (presumably both are views of one offsets
    buffer -- confirm against ``offsetsaliased``); otherwise ``_starts``
    and ``_stops`` are counted separately.
    """
    if id(self) in seen:
        return 0
    seen.add(id(self))

    if self.offsetsaliased(self._starts, self._stops):
        structure_bytes = self._starts.base.nbytes
    else:
        structure_bytes = self._starts.nbytes + self._stops.nbytes

    if isinstance(self._content, self.numpy.ndarray):
        return structure_bytes + self._content.nbytes
    return structure_bytes + self._content._getnbytes(seen)

def __len__(self):
return len(self._starts)

Expand Down
7 changes: 7 additions & 0 deletions awkward/array/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ def maskedwhen(self):
def maskedwhen(self, value):
self._maskedwhen = bool(value)

def _getnbytes(self, seen):
    """Return the bytes of the mask buffer plus the bytes of the content.

    ``seen`` tracks (by ``id()``) which arrays were already counted, so an
    array reachable through several paths is only counted the first time.
    """
    identity = id(self)
    if identity in seen:
        return 0
    seen.add(identity)

    mask_bytes = self._mask.nbytes
    payload = self._content
    payload_bytes = payload.nbytes if isinstance(payload, self.numpy.ndarray) else payload._getnbytes(seen)
    return mask_bytes + payload_bytes

def __len__(self):
return len(self._mask)

Expand Down
7 changes: 7 additions & 0 deletions awkward/array/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ def kwargs(self, value):
raise TypeError("kwargs must be a dict")
self._kwargs = dict(value)

def _getnbytes(self, seen):
    """Return the byte count of the wrapped content.

    Returns 0 when this array's ``id()`` is already in ``seen``; otherwise
    records it there and reports the content's size (``nbytes`` for a numpy
    array, ``_getnbytes(seen)`` for anything else).
    """
    if id(self) in seen:
        return 0
    seen.add(id(self))

    wrapped = self._content
    if isinstance(wrapped, self.numpy.ndarray):
        return wrapped.nbytes
    return wrapped._getnbytes(seen)

def __len__(self):
return len(self._content)

Expand Down
7 changes: 7 additions & 0 deletions awkward/array/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,13 @@ def contents(self, value):
value[n] = self._util_toarray(value[n], self.DEFAULTTYPE)
self._contents = value

def _getnbytes(self, seen):
    """Sum the byte counts of every column stored in ``_contents``.

    ``seen`` is the set of ``id()`` values of arrays already counted; this
    table returns 0 if it is already in the set, otherwise it adds itself
    and totals its columns (numpy columns via ``nbytes``, others via their
    own ``_getnbytes(seen)``).
    """
    if id(self) in seen:
        return 0
    seen.add(id(self))

    total = 0
    for column in self._contents.values():
        if isinstance(column, self.numpy.ndarray):
            total += column.nbytes
        else:
            total += column._getnbytes(seen)
    return total

def __len__(self):
return self._length()

Expand Down
7 changes: 7 additions & 0 deletions awkward/array/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ def dtype(self):

return self._dtype

def _getnbytes(self, seen):
    """Return the bytes of the tags and index buffers plus all contents.

    ``seen`` is the set of ``id()`` values of arrays already counted in this
    traversal; if this array is already present, 0 is returned so that it
    is never counted twice.

    The array's own ``_tags`` and ``_index`` buffers are included in the
    total, in addition to every entry of ``_contents`` (numpy arrays via
    ``nbytes``, anything else via its own ``_getnbytes(seen)``).
    """
    if id(self) in seen:
        return 0
    seen.add(id(self))

    # NOTE(review): assumes _tags and _index are numpy arrays stored under
    # these attribute names (the constructor takes tags, index, contents) --
    # confirm against the tags/index setters.
    own_bytes = self._tags.nbytes + self._index.nbytes

    contents_bytes = sum(
        x.nbytes if isinstance(x, self.numpy.ndarray) else x._getnbytes(seen)
        for x in self._contents
    )
    return own_bytes + contents_bytes

def __len__(self):
return len(self._tags)

Expand Down
37 changes: 33 additions & 4 deletions awkward/array/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,20 @@ def __ne__(self, other):
def __getstate__(self):
raise RuntimeError("VirtualArray.TransientKeys are not unique across processes, and hence should not be serialized")

def __init__(self, generator, args=(), kwargs={}, cache=None, persistentkey=None, type=None, persistvirtual=True):
def __init__(self, generator, args=(), kwargs={}, cache=None, persistentkey=None, type=None, nbytes=None, persistvirtual=True):
self.generator = generator
self.args = args
self.kwargs = kwargs
self.cache = cache
self.persistentkey = persistentkey
self.type = type
self.nbytes = nbytes
self.persistvirtual = persistvirtual
self._array = None
self._setitem = None
self._delitem = None

def copy(self, generator=None, args=None, kwargs=None, cache=None, persistentkey=None, type=None, persistvirtual=None):
def copy(self, generator=None, args=None, kwargs=None, cache=None, persistentkey=None, type=None, nbytes=None, persistvirtual=None):
# FIXME: arguments through **kwargs because undef is different from None (None has meaning for some of them)
out = self.__class__.__new__(self.__class__)
out._generator = self._generator
Expand All @@ -78,6 +79,7 @@ def copy(self, generator=None, args=None, kwargs=None, cache=None, persistentkey
out._cache = self._cache
out._persistentkey = self._persistentkey
out._type = self._type
out._nbytes = self._nbytes
out._persistvirtual = self._persistvirtual
out._array = self._array
if self._setitem is None:
Expand All @@ -100,12 +102,14 @@ def copy(self, generator=None, args=None, kwargs=None, cache=None, persistentkey
out.persistentkey = persistentkey
if type is not None:
out.type = type
if nbytes is not None:
out.nbytes = nbytes
if persistvirtual is not None:
out.persistvirtual = persistvirtual
return out

def deepcopy(self, generator=None, args=None, kwargs=None, cache=None, persistentkey=None, type=None, persistvirtual=None):
out = self.copy(generator=generator, args=arge, kwargs=kwargs, cache=cache, persistentkey=persistentkey, type=type, persistvirtual=persistvirtual)
def deepcopy(self, generator=None, args=None, kwargs=None, cache=None, persistentkey=None, type=None, nbytes=None, persistvirtual=None):
out = self.copy(generator=generator, args=arge, kwargs=kwargs, cache=cache, persistentkey=persistentkey, type=type, nbytes=nbytes, persistvirtual=persistvirtual)
out._array = self._util_deepcopy(out._array)
if out._setitem is not None:
for n in list(out._setitem):
Expand Down Expand Up @@ -242,6 +246,31 @@ def type(self, value):
raise TypeError("type must be None or an awkward type (to set Numpy parameters, use awkward.util.fromnumpy(shape, dtype, masked=False))")
self._type = value

def _getnbytes(self, seen):
    """Return the byte count of this virtual array.

    ``seen`` is the set of ``id()`` values of arrays already counted; 0 is
    returned if this array is already in it.

    If an ``_nbytes`` value has been set and the array is not materialized,
    that stored value is returned without touching ``self.array``.
    Otherwise ``self.array`` is read and measured (``nbytes`` for a numpy
    array, ``_getnbytes(seen)`` for anything else).
    """
    if id(self) in seen:
        return 0
    seen.add(id(self))

    if self._nbytes is not None and not self.ismaterialized:
        return self._nbytes

    array = self.array
    if isinstance(array, self.numpy.ndarray):
        return array.nbytes
    return array._getnbytes(seen)

@property
def nbytes(self):
    """Byte count of this array, computed by ``_getnbytes`` from a fresh set."""
    visited = set()
    return self._getnbytes(visited)

@nbytes.setter
def nbytes(self, value):
    """Store ``value`` as ``_nbytes``.

    When ``check_prop_valid`` is true and ``value`` is not None, the value
    must pass ``_util_isinteger`` (else TypeError) and be non-negative
    (else ValueError). None is always accepted.
    """
    must_validate = self.check_prop_valid and value is not None
    if must_validate:
        if not self._util_isinteger(value):
            raise TypeError("nbytes must be an integer or None")
        if value < 0:
            raise ValueError("nbytes must be a non-negative integer or None")
    self._nbytes = value

def __len__(self):
return self.shape[0]

Expand Down
2 changes: 1 addition & 1 deletion awkward/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import re

__version__ = "0.8.7"
__version__ = "0.8.8"
version = __version__
version_info = tuple(re.split(r"[-\.]", __version__))

Expand Down
3 changes: 3 additions & 0 deletions tests/test_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class Test(unittest.TestCase):
def runTest(self):
pass

def test_chunked_nbytes(self):
    """``nbytes`` of a ChunkedArray built from list chunks is an ``int``."""
    chunks = [[], [0, 1, 2, 3, 4], [5, 6], [], [7, 8, 9], []]
    array = ChunkedArray(chunks)
    assert isinstance(array.nbytes, int)

def test_chunked_allslices(self):
chunked = ChunkedArray([[], [0.0, 1.1, 2.2, 3.3, 4.4], [5.5, 6.6], [], [7.7], [8.8, 9.9], []])
regular = numpy.concatenate(chunked.chunks).tolist()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class Test(unittest.TestCase):
def runTest(self):
pass

def test_indexed_nbytes(self):
    """``nbytes`` of an IndexedArray is an ``int``."""
    index = [3, 2, 4, 2, 2, 4, 0]
    content = [0.0, 1.1, 2.2, 3.3, 4.4]
    array = IndexedArray(index, content)
    assert isinstance(array.nbytes, int)

def test_indexed_get(self):
a = IndexedArray([3, 2, 4, 2, 2, 4, 0], [0.0, 1.1, 2.2, 3.3, 4.4])
assert [x for x in a] == [3.3, 2.2, 4.4, 2.2, 2.2, 4.4, 0.0]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_jagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class Test(unittest.TestCase):
def runTest(self):
pass

def test_jagged_nbytes(self):
    """``nbytes`` of a JaggedArray built from starts/stops/content is an ``int``."""
    starts = [0, 3, 3, 5]
    stops = [3, 3, 5, 10]
    content = [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]
    array = JaggedArray(starts, stops, content)
    assert isinstance(array.nbytes, int)

def test_jagged_init(self):
a = JaggedArray([0, 3, 3, 5], [3, 3, 5, 10], [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
assert a.tolist() == [[0.0, 1.1, 2.2], [], [3.3, 4.4], [5.5, 6.6, 7.7, 8.8, 9.9]]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class Test(unittest.TestCase):
def runTest(self):
pass

def test_masked_nbytes(self):
    """``nbytes`` of a MaskedArray is an ``int``."""
    mask = [True, False, True, False, True, False, True, False, True, False]
    content = [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]
    array = MaskedArray(mask, content, maskedwhen=True)
    assert isinstance(array.nbytes, int)

def test_masked_get(self):
a = MaskedArray([True, False, True, False, True, False, True, False, True, False], [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9], maskedwhen=True)
assert a.tolist() == [None, 1.1, None, 3.3, None, 5.5, None, 7.7, None, 9.9]
Expand Down
11 changes: 11 additions & 0 deletions tests/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ class Test(unittest.TestCase):
def runTest(self):
pass

def test_object_nbytes(self):
    """``nbytes`` of an ObjectArray wrapping 3-element rows is an ``int``."""
    class Point(object):
        # Minimal generator target: unpacks one row into x, y, z.
        def __init__(self, array):
            self.x, self.y, self.z = array

        def __repr__(self):
            return "<Point {0} {1} {2}>".format(self.x, self.y, self.z)

        def __eq__(self, other):
            return isinstance(other, Point) and self.x == other.x and self.y == other.y and self.z == other.z

    rows = [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]
    array = ObjectArray(rows, Point)
    assert isinstance(array.nbytes, int)

def test_object_floats(self):
class Point(object):
def __init__(self, array):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class Test(unittest.TestCase):
def runTest(self):
pass

def test_table_nbytes(self):
    """``nbytes`` of a two-column Table is an ``int``."""
    integers = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    floats = [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]
    table = Table(integers, floats)
    assert isinstance(table.nbytes, int)

def test_table_get(self):
a = Table([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])

Expand Down
3 changes: 3 additions & 0 deletions tests/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class Test(unittest.TestCase):
def runTest(self):
pass

def test_union_nbytes(self):
    """``nbytes`` of a UnionArray built from tags/index/contents is an ``int``."""
    tags = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
    index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    contents = [
        [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9],
        [0, 100, 200, 300, 400, 500, 600, 700, 800, 900],
    ]
    array = UnionArray(tags, index, contents)
    assert isinstance(array.nbytes, int)

def test_union_get(self):
a = UnionArray([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [[0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9], [0, 100, 200, 300, 400, 500, 600, 700, 800, 900]])
assert a.tolist() == [0.0, 100, 2.2, 300, 4.4, 500, 6.6, 700, 8.8, 900]
Expand Down
4 changes: 4 additions & 0 deletions tests/test_virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class Test(unittest.TestCase):
def runTest(self):
pass

def test_virtual_nbytes(self):
    """``nbytes`` is an ``int`` by default and echoes an explicit ``nbytes`` argument."""
    measured = VirtualArray(lambda: [1, 2, 3])
    assert isinstance(measured.nbytes, int)

    hinted = VirtualArray(lambda: [1, 2, 3], nbytes=12345)
    assert hinted.nbytes == 12345

def test_virtual_nocache(self):
a = VirtualArray(lambda: [1, 2, 3])
assert not a.ismaterialized
Expand Down

0 comments on commit bd803ca

Please sign in to comment.