Skip to content

Commit

Permalink
Make Indexer classes not inherit from tuple. (#1705)
Browse files Browse the repository at this point in the history
* Make Indexer classes not inherit from tuple.

I'm not entirely sure this is a good idea. The advantage is that it makes
all of our indexing code entirely explicit: everything that reaches a backend
*must* be an ExplicitIndexer. The downside is that it removes a bit of
internal flexibility: we can no longer use plain tuples in place of basic
indexers. On the whole, I think this is probably worth it, but I would
appreciate feedback.

* Add validation to ExplicitIndexer classes

* Fix pynio test failure

* Rename and add comments

* flake8

* Fix windows test failure

* typo

* leftover from debugging
  • Loading branch information
shoyer authored Nov 14, 2017
1 parent 9d8ec38 commit ac854f0
Show file tree
Hide file tree
Showing 19 changed files with 515 additions and 262 deletions.
10 changes: 9 additions & 1 deletion xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from distutils.version import LooseVersion

from ..conventions import cf_encoder
from ..core.utils import FrozenOrderedDict
from ..core import indexing
from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin
from ..core.pycompat import iteritems, dask_array_type

try:
Expand Down Expand Up @@ -76,6 +77,13 @@ def robust_getitem(array, key, catch=Exception, max_retries=6,
time.sleep(1e-3 * next_delay)


class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
    """Base class for backend array wrappers.

    Subclasses implement ``__getitem__`` taking an ExplicitIndexer;
    coercion to a NumPy array loads the entire variable by indexing
    with a full slice along every dimension.
    """

    def __array__(self, dtype=None):
        # Build an explicit basic index selecting everything, then let
        # the subclass's __getitem__ fetch the data before coercion.
        full_slice = (slice(None),) * self.ndim
        data = self[indexing.BasicIndexer(full_slice)]
        return np.asarray(data, dtype=dtype)


class AbstractDataStore(Mapping):
_autoclose = False

Expand Down
7 changes: 2 additions & 5 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@

class H5NetCDFArrayWrapper(BaseNetCDF4Array):
def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))
key = indexing.to_tuple(key)
key = indexing.unwrap_explicit_indexer(
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
with self.datastore.ensure_open(autoclose=True):
return self.get_array()[key]

Expand Down
17 changes: 5 additions & 12 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
from .. import Variable
from ..conventions import pop_to
from ..core import indexing
from ..core.utils import (FrozenOrderedDict, NdimSizeLenMixin,
DunderArrayMixin, close_on_error,
is_remote_uri)
from ..core.utils import (FrozenOrderedDict, close_on_error, is_remote_uri)
from ..core.pycompat import iteritems, basestring, OrderedDict, PY3, suppress

from .common import (WritableCFDataStore, robust_getitem,
from .common import (WritableCFDataStore, robust_getitem, BackendArray,
DataStorePickleMixin, find_root)
from .netcdf3 import (encode_nc3_attr_value, encode_nc3_variable)

Expand All @@ -27,8 +25,7 @@
'|': 'native'}


class BaseNetCDF4Array(NdimSizeLenMixin, DunderArrayMixin,
indexing.NDArrayIndexable):
class BaseNetCDF4Array(BackendArray):
def __init__(self, variable_name, datastore):
self.datastore = datastore
self.variable_name = variable_name
Expand All @@ -51,12 +48,8 @@ def get_array(self):

class NetCDF4ArrayWrapper(BaseNetCDF4Array):
def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))

key = indexing.to_tuple(key)
key = indexing.unwrap_explicit_indexer(
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))

if self.datastore.is_remote: # pragma: no cover
getitem = functools.partial(robust_getitem, catch=RuntimeError)
Expand Down
20 changes: 6 additions & 14 deletions xarray/backends/pydap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import numpy as np

from .. import Variable
from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin
from ..core.utils import FrozenOrderedDict, Frozen
from ..core import indexing
from ..core.pycompat import integer_types

from .common import AbstractDataStore, robust_getitem
from .common import AbstractDataStore, BackendArray, robust_getitem


class PydapArrayWrapper(NDArrayMixin, indexing.NDArrayIndexable):
class PydapArrayWrapper(BackendArray):
def __init__(self, array):
self.array = array

Expand All @@ -27,17 +27,9 @@ def dtype(self):
return np.dtype(t.typecode + str(t.size))

def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))
key = indexing.to_tuple(key)
if not isinstance(key, tuple):
key = (key,)
for k in key:
if not (isinstance(k, integer_types + (slice,)) or k is Ellipsis):
raise IndexError('pydap only supports indexing with int, '
'slice and Ellipsis objects')
key = indexing.unwrap_explicit_indexer(
key, target=self, allow=indexing.BasicIndexer)

# pull the data from the array attribute if possible, to avoid
# downloading coordinate data twice
array = getattr(self.array, 'array', self.array)
Expand Down
19 changes: 6 additions & 13 deletions xarray/backends/pynio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@
import numpy as np

from .. import Variable
from ..core.utils import (FrozenOrderedDict, Frozen,
NdimSizeLenMixin, DunderArrayMixin)
from ..core.utils import (FrozenOrderedDict, Frozen)
from ..core import indexing
from ..core.pycompat import integer_types

from .common import AbstractDataStore, DataStorePickleMixin
from .common import AbstractDataStore, DataStorePickleMixin, BackendArray


class NioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin,
indexing.NDArrayIndexable):
class NioArrayWrapper(BackendArray):

def __init__(self, variable_name, datastore):
self.datastore = datastore
Expand All @@ -30,13 +27,9 @@ def get_array(self):
return self.datastore.ds.variables[self.variable_name]

def __getitem__(self, key):
if isinstance(key, (indexing.VectorizedIndexer,
indexing.OuterIndexer)):
raise NotImplementedError(
'Nio backend does not support vectorized / outer indexing. '
'Load your data first with .load() or .compute(). '
'Given {}'.format(key))
key = indexing.to_tuple(key)
key = indexing.unwrap_explicit_indexer(
key, target=self, allow=indexing.BasicIndexer)

with self.datastore.ensure_open(autoclose=True):
array = self.get_array()
if key == () and self.ndim == 0:
Expand Down
14 changes: 6 additions & 8 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import numpy as np

from .. import DataArray
from ..core.utils import DunderArrayMixin, NdimSizeLenMixin, is_scalar
from ..core.utils import is_scalar
from ..core import indexing
from .common import BackendArray
try:
from dask.utils import SerializableLock as Lock
except ImportError:
Expand All @@ -17,8 +18,7 @@
'first.')


class RasterioArrayWrapper(NdimSizeLenMixin, DunderArrayMixin,
indexing.NDArrayIndexable):
class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""
def __init__(self, rasterio_ds):
self.rasterio_ds = rasterio_ds
Expand All @@ -38,11 +38,9 @@ def shape(self):
return self._shape

def __getitem__(self, key):
if isinstance(key, indexing.VectorizedIndexer):
raise NotImplementedError(
'Vectorized indexing for {} is not implemented. Load your '
'data first with .load() or .compute().'.format(type(self)))
key = indexing.to_tuple(key)
key = indexing.unwrap_explicit_indexer(
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))

# bands cannot be windowed but they can be listed
band_key = key[0]
n_bands = self.shape[0]
Expand Down
9 changes: 4 additions & 5 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@

from .. import Variable
from ..core.pycompat import iteritems, OrderedDict, basestring
from ..core.utils import (Frozen, FrozenOrderedDict, NdimSizeLenMixin,
DunderArrayMixin)
from ..core.indexing import NumpyIndexingAdapter, NDArrayIndexable
from ..core.utils import (Frozen, FrozenOrderedDict)
from ..core.indexing import NumpyIndexingAdapter

from .common import WritableCFDataStore, DataStorePickleMixin
from .common import WritableCFDataStore, DataStorePickleMixin, BackendArray
from .netcdf3 import (is_valid_nc3_name, encode_nc3_attr_value,
encode_nc3_variable)

Expand All @@ -31,7 +30,7 @@ def _decode_attrs(d):
for (k, v) in iteritems(d))


class ScipyArrayWrapper(NdimSizeLenMixin, DunderArrayMixin, NDArrayIndexable):
class ScipyArrayWrapper(BackendArray):

def __init__(self, variable_name, datastore):
self.datastore = datastore
Expand Down
29 changes: 15 additions & 14 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def encode_cf_timedelta(timedeltas, units=None):
return (num, units)


class MaskedAndScaledArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
class MaskedAndScaledArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Wrapper around array-like objects to create a new indexable object where
values, when accessed, are automatically scaled and masked according to
CF conventions for packed and missing data values.
Expand Down Expand Up @@ -395,7 +395,7 @@ def __repr__(self):
self.scale_factor, self.add_offset, self._dtype))


class DecodedCFDatetimeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
class DecodedCFDatetimeArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Wrapper around array-like objects to create a new indexable object where
values, when accessed, are automatically converted into datetime objects
using decode_cf_datetime.
Expand All @@ -408,8 +408,9 @@ def __init__(self, array, units, calendar=None):
# Verify that at least the first and last date can be decoded
# successfully. Otherwise, tracebacks end up swallowed by
# Dataset.__repr__ when users try to view their lazily decoded array.
example_value = np.concatenate([first_n_items(array, 1) or [0],
last_item(array) or [0]])
values = indexing.ImplicitToExplicitIndexingAdapter(self.array)
example_value = np.concatenate([first_n_items(values, 1) or [0],
last_item(values) or [0]])

try:
result = decode_cf_datetime(example_value, units, calendar)
Expand All @@ -434,7 +435,7 @@ def __getitem__(self, key):
calendar=self.calendar)


class DecodedCFTimedeltaArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
class DecodedCFTimedeltaArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Wrapper around array-like objects to create a new indexable object where
values, when accessed, are automatically converted into timedelta objects
using decode_cf_timedelta.
Expand All @@ -451,7 +452,7 @@ def __getitem__(self, key):
return decode_cf_timedelta(self.array[key], units=self.units)


class StackedBytesArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Wrapper around array-like objects to create a new indexable object where
values, when accessed, are automatically stacked along the last dimension.
Expand Down Expand Up @@ -482,7 +483,7 @@ def shape(self):
def __str__(self):
# TODO(shoyer): figure out why we need this special case?
if self.ndim == 0:
return str(self[...].item())
return str(np.array(self).item())
else:
return repr(self)

Expand All @@ -491,13 +492,13 @@ def __repr__(self):

def __getitem__(self, key):
# require slicing the last dimension completely
key = indexing.expanded_indexer(key, self.array.ndim)
if key[-1] != slice(None):
key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim))
if key.tuple[-1] != slice(None):
raise IndexError('too many indices')
return char_to_bytes(self.array[key])


class BytesToStringArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
class BytesToStringArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Wrapper that decodes bytes to unicode when values are read.
>>> BytesToStringArray(np.array([b'abc']))[:]
Expand All @@ -524,7 +525,7 @@ def dtype(self):
def __str__(self):
# TODO(shoyer): figure out why we need this special case?
if self.ndim == 0:
return str(self[...].item())
return str(np.array(self).item())
else:
return repr(self)

Expand All @@ -536,7 +537,7 @@ def __getitem__(self, key):
return decode_bytes_array(self.array[key], self.encoding)


class NativeEndiannessArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Decode arrays on the fly from non-native to native endianness
This is useful for decoding arrays from netCDF3 files (which are all
Expand Down Expand Up @@ -565,7 +566,7 @@ def __getitem__(self, key):
return np.asarray(self.array[key], dtype=self.dtype)


class BoolTypeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Decode arrays on the fly from integer to boolean datatype
This is useful for decoding boolean arrays from integer typed netCDF
Expand Down Expand Up @@ -593,7 +594,7 @@ def __getitem__(self, key):
return np.asarray(self.array[key], dtype=self.dtype)


class UnsignedIntTypeArray(utils.NDArrayMixin, indexing.NDArrayIndexable):
class UnsignedIntTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Decode arrays on the fly from signed integer to unsigned
integer. Typically used when _Unsigned is set at as a netCDF
attribute on a signed integer variable.
Expand Down
Loading

0 comments on commit ac854f0

Please sign in to comment.