Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Commit

Permalink
Merge pull request #160 from nsmith-/argsort
Browse files Browse the repository at this point in the history
Implement jagged argsort and speed up argminmax
  • Loading branch information
jpivarski authored Jul 17, 2019
2 parents 513f28a + 10a3f97 commit e7496ff
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 57 deletions.
106 changes: 50 additions & 56 deletions awkward/array/jagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def offsets2parents(cls, offsets):

@classmethod
def startsstops2parents(cls, starts, stops):
assert starts.shape == stops.shape
starts = starts.reshape(-1) # flatten in case multi-d jagged
stops = stops.reshape(-1)
dtype = cls.JaggedArray.fget(None).INDEXTYPE
out = cls.numpy.full(stops.max(), -1, dtype=dtype)
indices = cls.numpy.arange(len(starts), dtype=dtype)
Expand Down Expand Up @@ -390,9 +393,9 @@ def counts(self, value):
def parents(self):
if self._parents is None:
self._valid()
try:
if self._canuseoffset():
self._parents = self.offsets2parents(self.offsets)
except ValueError:
else:
self._parents = self.startsstops2parents(self._starts, self._stops)
return self._parents

Expand All @@ -417,10 +420,11 @@ def localindex(self):
out -= self.offsets[self.parents[self.parents >= 0]]
return self.JaggedArray.fromoffsets(self.offsets - self.offsets[0], out)
else:
offsets = self.counts2offsets(self.counts)
counts = self.counts.reshape(-1) # flatten in case multi-d jagged
offsets = self.counts2offsets(counts)
out = self.numpy.arange(offsets[-1], dtype=self.INDEXTYPE)
out -= self.numpy.repeat(offsets[:-1], awkward.util.windows_safe(self.counts))
return self.JaggedArray.fromoffsets(offsets, out)
out -= self.numpy.repeat(offsets[:-1], awkward.util.windows_safe(counts))
return self.JaggedArray(offsets[:-1].reshape(self.shape), offsets[1:].reshape(self.shape), out)

def _getnbytes(self, seen):
if id(self) in seen:
Expand Down Expand Up @@ -567,6 +571,11 @@ def __getitem__(self, where):

return self.copy(starts=offsets[:-1].reshape(intheadsum.shape), stops=offsets[1:].reshape(intheadsum.shape), content=thyself._content[headcontent])

elif head.shape == self.shape and issubclass(head._content.dtype.type, (self.numpy.bool, self.numpy.bool_)):
index = self.localindex + self.starts
flatindex = index.flatten()[head.flatten()]
return self.JaggedArray.fromcounts(head.sum(), self._content[flatindex])

else:
raise TypeError("jagged index must be boolean (mask) or integer (fancy indexing)")

Expand Down Expand Up @@ -954,15 +963,15 @@ def recurse(x):
if len(x.shape) == 0:
content = self.numpy.full(len(parents), x, dtype=x.dtype)
else:
content = x[parents]
content = x.reshape(-1)[parents]
return content

else:
content = self.numpy.empty(len(parents), dtype=x.dtype)
if len(x.shape) == 0:
content[good] = x
else:
content[good] = x[parents[good]]
content[good] = x.reshape(-1)[parents[good]]
return content

content = recurse(data)
Expand Down Expand Up @@ -1516,60 +1525,45 @@ def _argminmax(self, ismin):
if len(self._content) == 0:
return self.copy(content=self.numpy.array([], dtype=self.INDEXTYPE))

contentmax = self._content.max()
shiftval = self.numpy.ceil(contentmax) + 1
if math.isnan(shiftval) or math.isinf(shiftval) or shiftval != contentmax:
return self._argminmax_general(ismin)

flatstarts = self._starts.reshape(-1)
flatstops = self._stops.reshape(-1)

nonempty = (flatstarts != flatstops)
nonterminal = (flatstarts < len(self._content))
flatstarts = flatstarts[nonterminal]
flatstops = flatstops[nonterminal]

shift = self.numpy.zeros(self._content.shape, dtype=self.INDEXTYPE)
shift[flatstarts] = shiftval
self.numpy.cumsum(shift, out=shift)

sortedindex = (self._content + shift).argsort()

if ismin:
flatout = sortedindex[flatstarts] - flatstarts
out = self.localindex[self.min() == self]
else:
flatout = sortedindex[flatstops - 1] - flatstarts

newstarts = self.numpy.arange(len(nonempty), dtype=self.INDEXTYPE).reshape(self._starts.shape)
newstops = self.numpy.array(newstarts)
newstops.reshape(-1)[nonempty] += 1
return self.copy(starts=newstarts, stops=newstops, content=flatout)
out = self.localindex[self.max() == self]

def _argminmax_general(self, ismin):
if len(self._content.shape) != 1:
raise ValueError("cannot compute arg{0} because content is not one-dimensional".format("min" if ismin else "max"))
# workaround for lack of general out[...,:1] support
nonempty = out.counts > 0
out.stops[nonempty] = out.starts[nonempty] + 1
return out

if ismin:
optimum = self.numpy.argmin
def argsort(self, ascending=False):
self._valid()
if self._util_hasjagged(self._content):
return self.copy(content=self._content.argsort(ascending))
else:
optimum = self.numpy.argmax

out = self.numpy.empty(self._starts.shape + self._content.shape[1:], dtype=self.INDEXTYPE)

flatout = out.reshape((-1,) + self._content.shape[1:])
flatstarts = self._starts.reshape(-1)
flatstops = self._stops.reshape(-1)

content = self._content
for i, flatstart in enumerate(flatstarts):
flatstop = flatstops[i]
if flatstart != flatstop:
flatout[i] = optimum(content[flatstart:flatstop], axis=0)

newstarts = self.numpy.arange(len(flatstarts), dtype=self.INDEXTYPE).reshape(self._starts.shape)
newstops = self.numpy.array(newstarts)
newstops.reshape(-1)[flatstarts != flatstops] += 1
return self.copy(starts=newstarts, stops=newstops, content=flatout)
return self._argsort(ascending)

def _argsort(self, ascending=False):
    """Compute per-sublist sort indices by repeatedly extracting the extremum.

    On every pass the elements equal to their sublist's reduced value
    (``JaggedArray.min`` when ``ascending`` is true, ``JaggedArray.max``
    otherwise) are selected, their local indices are written into the
    next free slots of the output, and they are removed from the working
    copy.  The loop stops once the working copy has no content left.

    Returns a jagged array shaped like ``self.localindex`` whose stops
    are trimmed to the number of indices actually written per sublist.
    """
    if ascending:
        pick = self.JaggedArray.min
    else:
        pick = self.JaggedArray.max

    positions = self.localindex
    out = positions.empty_like()
    # number of output slots already filled in each sublist
    filled = self.numpy.zeros_like(out.starts)
    remaining = self.copy()

    while remaining.content.size > 0:
        selected = pick(remaining) == remaining
        if self.numpy.isnan(remaining.content).all():
            # only NaNs are left: take them now so they always come last
            selected = self.numpy.isnan(remaining)

        chosen = positions[selected]
        slots = out.starts + filled + chosen.localindex
        out._content[slots.flatten()] = chosen.content
        filled += chosen.counts

        # drop what was just placed and continue with the rest
        unplaced = ~selected
        remaining = remaining[unplaced]
        positions = positions[unplaced]

    # Masked entries, if any, are dropped by __getitem__ above, so fewer
    # indices than the original counts may have been written; shrink the
    # output sublists to match.
    out.stops = out.starts + filled
    return out

@classmethod
def _concatenate_axis0(cls, arrays):
Expand Down
2 changes: 1 addition & 1 deletion awkward/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import re

__version__ = "0.12.2"
__version__ = "0.12.3"
version = __version__
version_info = tuple(re.split(r"[-\.]", __version__))

Expand Down
5 changes: 5 additions & 0 deletions tests/test_jagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,8 @@ def test_jagged_parents(self):
assert a.parents.tolist() == [0, 0, 0, 0, 0, 2, 2, 2, 3]
b = a[[False, True, False, True]]
assert b.parents.tolist() == [-1, -1, -1, -1, -1, -1, -1, -1, 1]

def test_jagged_sort(self):
    """Check jagged ``argsort`` in both orders on data with inf, NaN and None.

    In the expected results the NaN position is listed last for either
    order, and the ``None`` entry of the final sublist has no index.
    """
    jagged = awkward.fromiter([
        [2., 3., 1.],
        [4., -numpy.inf, 5.],
        [numpy.inf, 4., numpy.nan, -numpy.inf],
        [numpy.nan],
        [3., None, 4., -1.],
    ])

    # default call (no argument)
    expected_default = [[1, 0, 2], [2, 0, 1], [0, 1, 3, 2], [0], [2, 0, 3]]
    # argsort(True), i.e. ascending
    expected_ascending = [[2, 0, 1], [1, 0, 2], [3, 1, 0, 2], [0], [3, 0, 2]]

    assert jagged.argsort().tolist() == expected_default
    assert jagged.argsort(True).tolist() == expected_ascending

0 comments on commit e7496ff

Please sign in to comment.