Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Commit

Permalink
Merge pull request #194 from scikit-hep/issue-180
Browse files Browse the repository at this point in the history
Allow ChunkedArrays to be crossed (and argcrossed) with ChunkedArrays that have the same chunk sizes.
  • Loading branch information
jpivarski authored Sep 27, 2019
2 parents 975649a + b22acc7 commit ff50522
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions awkward/array/chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def __getitem__(self, where):
where = (where,)
head, tail = where[0], where[1:]

if isinstance(head, self.ChunkedArray):
if isinstance(head, ChunkedArray):
if not self._aligned(head):
raise ValueError("A ChunkedArray can only be used as a slice of a ChunkedArray if they have the same chunk sizes")
chunks = []
Expand Down Expand Up @@ -630,14 +630,32 @@ def argpairs(self, nested=False):
return out

def cross(self, other, nested=False):
out = self.copy(chunks=[x.cross(other, nested=nested) for x in self._chunks])
out.knowchunksizes()
return out
if not isinstance(other, ChunkedArray) or not self._aligned(other):
raise ValueError("A ChunkedArray can only be crossed with a ChunkedArray if they have the same chunk sizes")
chunks = []
chunksizes = []
for c, h in zip(self.chunks, other.chunks):
if isinstance(c, awkward.array.virtual.VirtualArray):
c = c.array
if isinstance(h, awkward.array.virtual.VirtualArray):
h = h.array
chunks.append(c.cross(h, nested=nested))
chunksizes.append(len(chunks[-1]))
return self.copy(chunks=chunks, chunksizes=chunksizes)

def argcross(self, other, nested=False):
out = self.copy(chunks=[x.argcross(other, nested=nested) for x in self._chunks])
out.knowchunksizes()
return out
if not isinstance(other, ChunkedArray) or not self._aligned(other):
raise ValueError("A ChunkedArray can only be crossed with a ChunkedArray if they have the same chunk sizes")
chunks = []
chunksizes = []
for c, h in zip(self.chunks, other.chunks):
if isinstance(c, awkward.array.virtual.VirtualArray):
c = c.array
if isinstance(h, awkward.array.virtual.VirtualArray):
h = h.array
chunks.append(c.argcross(h, nested=nested))
chunksizes.append(len(chunks[-1]))
return self.copy(chunks=chunks, chunksizes=chunksizes)

def flattentuple(self):
return self.copy(chunks=[self._util_flattentuple(x) for x in self._chunks], chunksizes=self._chunksizes)
Expand Down

0 comments on commit ff50522

Please sign in to comment.