Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Commit

Permalink
special cases and tests for nested Tables and ObjectArrays in JaggedA…
Browse files Browse the repository at this point in the history
…rrays
  • Loading branch information
jpivarski committed Aug 24, 2018
1 parent 8c0a597 commit 8f42fd4
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 8 deletions.
39 changes: 32 additions & 7 deletions awkward/array/jagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,9 @@ def _tojagged(self, starts=None, stops=None, copy=True):
return out

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
import awkward.array.objects
import awkward.array.table

self._valid()

if method != "__call__":
Expand All @@ -593,6 +596,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
else:
inputs[i] = inputs[i]._tojagged(starts, stops, copy=False)

elif isinstance(inputs[i], awkward.array.base.AwkwardArray):
pass

else:
inputs[i] = awkward.util.numpy.array(inputs[i], copy=False)

Expand All @@ -604,19 +610,34 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
assert False

for i in range(len(inputs)):
if isinstance(inputs[i], awkward.util.numpy.ndarray):
if isinstance(inputs[i], (awkward.util.numpy.ndarray, awkward.array.base.AwkwardArray)) and not isinstance(inputs[i], JaggedArray):
data = awkward.util.toarray(inputs[i], inputs[i].dtype, (awkward.util.numpy.ndarray, awkward.array.base.AwkwardArray))
if starts.shape != data.shape:
raise ValueError("cannot broadcast JaggedArray of shape {0} with Numpy array of shape {1}".format(starts.shape, data.shape))

if parents is None:
parents = jaggedarray.parents
good = (parents >= 0)

content = awkward.util.numpy.empty(len(parents), dtype=data.dtype)
if len(data.shape) == 0:
content[good] = data
elif starts.shape != data.shape:
raise ValueError("cannot broadcast JaggedArray of shape {0} with Numpy array of shape {1}".format(starts.shape, data.shape))
if isinstance(data, awkward.array.objects.ObjectArray):
content = awkward.util.numpy.empty(len(parents), dtype=data.content.dtype)
content[good] = data.content[parents[good]]
content = data.copy(content=content)

elif isinstance(data, awkward.array.table.Table):
content = data.empty_like()
for n in data.columns:
x = data[n]
content[n] = awkward.util.numpy.empty(len(parents), dtype=x.dtype)
content[n][good] = x[parents[good]]

else:
content[good] = data[parents[good]]
content = awkward.util.numpy.empty(len(parents), dtype=data.dtype)
if len(data.shape) == 0:
content[good] = data
else:
content[good] = data[parents[good]]

inputs[i] = self.copy(starts=starts, stops=stops, content=content)

for i in range(len(inputs)):
Expand All @@ -628,6 +649,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

result = getattr(ufunc, method)(*inputs, **kwargs)

# print(result)
# print("**************************************")
# print()

if isinstance(result, tuple):
return tuple(self.copy(starts=starts, stops=stops, content=x) for x in result)
elif method == "at":
Expand Down
2 changes: 1 addition & 1 deletion awkward/array/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __contains__(self, name):

def __getattr__(self, name):
if name == "tolist":
return lambda: dict((n, x[self._index]) for n, x in self._table._content.items())
return lambda: dict((n, self._table._try_tolist(x[self._index])) for n, x in self._table._content.items())

content = self._table._content.get(name, None)
if content is not None:
Expand Down
17 changes: 17 additions & 0 deletions tests/test_jagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ def test_jagged_ufunc(self):
self.assertEqual((100 + a).tolist(), [[100.0, 101.1, 102.2], [], [103.3, 104.4], [105.5, 106.6, 107.7, 108.8, 109.9]])
self.assertEqual((numpy.array([100, 200, 300, 400]) + a).tolist(), [[100.0, 101.1, 102.2], [], [303.3, 304.4], [405.5, 406.6, 407.7, 408.8, 409.9]])

def test_jagged_ufunc_object(self):
a = JaggedArray([0, 3, 3, 5], [3, 3, 5, 10], [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
self.assertEqual((awkward.ObjectArray([100, 200, 300, 400], str) + a).tolist(), ["[100. 101.1 102.2]", "[]", "[303.3 304.4]", "[405.5 406.6 407.7 408.8 409.9]"])

a = JaggedArray([0, 3, 3, 5], [3, 3, 5, 10], awkward.ObjectArray([0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9], str))
self.assertEqual(a, [["0.0", "1.1", "2.2"], [], ["3.3", "4.4"], ["5.5", "6.6", "7.7", "8.8", "9.9"]])
self.assertEqual(a + awkward.ObjectArray([100, 200, 300, 400], str), [["100.0", "101.1", "102.2"], [], ["303.3", "304.4"], ["405.5", "406.6", "407.7", "408.8", "409.9"]])

def test_jagged_ufunc_table(self):
a = JaggedArray([0, 3, 3, 5], [3, 3, 5, 10], [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
self.assertEqual((awkward.Table(x=[100, 200, 300, 400], y=[1000, 2000, 3000, 4000]) + a).tolist(), [{"x": [100.0, 101.1, 102.2], "y": [1000.0, 1001.1, 1002.2]}, {"x": [], "y": []}, {"x": [303.3, 304.4], "y": [3003.3, 3004.4]}, {"x": [405.5, 406.6, 407.7, 408.8, 409.9], "y": [4005.5, 4006.6, 4007.7, 4008.8, 4009.9]}])

a = JaggedArray([0, 3, 3, 5], [3, 3, 5, 10], awkward.Table(x=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], y=[0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]))
self.assertEqual((a + 1000).tolist(), [[{"x": 1000, "y": 1000.0}, {"x": 1001, "y": 1001.1}, {"x": 1002, "y": 1002.2}], [], [{"x": 1003, "y": 1003.3}, {"x": 1004, "y": 1004.4}], [{"x": 1005, "y": 1005.5}, {"x": 1006, "y": 1006.6}, {"x": 1007, "y": 1007.7}, {"x": 1008, "y": 1008.8}, {"x": 1009, "y": 1009.9}]])
self.assertEqual((a + numpy.array([100, 200, 300, 400])).tolist(), [[{"x": 100, "y": 100.0}, {"x": 101, "y": 101.1}, {"x": 102, "y": 102.2}], [], [{"x": 303, "y": 303.3}, {"x": 304, "y": 304.4}], [{"x": 405, "y": 405.5}, {"x": 406, "y": 406.6}, {"x": 407, "y": 407.7}, {"x": 408, "y": 408.8}, {"x": 409, "y": 409.9}]])
self.assertEqual((a + awkward.Table(x=[100, 200, 300, 400], y=[1000, 2000, 3000, 4000])).tolist(), [[{"x": 100, "y": 1000.0}, {"x": 101, "y": 1001.1}, {"x": 102, "y": 1002.2}], [], [{"x": 303, "y": 3003.3}, {"x": 304, "y": 3004.4}], [{"x": 405, "y": 4005.5}, {"x": 406, "y": 4006.6}, {"x": 407, "y": 4007.7}, {"x": 408, "y": 4008.8}, {"x": 409, "y": 4009.9}]])

def test_jagged_cross(self):
pass

Expand Down

0 comments on commit 8f42fd4

Please sign in to comment.