Skip to content

Commit

Permalink
Merge pull request #1386 from proektlab/save-nan-fix
Browse files Browse the repository at this point in the history
Compare kind of dtype rather than dtype itself
  • Loading branch information
pgunn authored Aug 9, 2024
2 parents 65d8580 + 7c72e86 commit f32e8ad
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
51 changes: 51 additions & 0 deletions caiman/tests/test_hdf5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python

import numpy as np
import numpy.testing as npt
import os
from caiman.utils import utils
from caiman.paths import get_tempdir


def _recursively_assert_array_equal(a, b):
"""Get around array_equal not ignoring nans for nested objects"""
if isinstance(a, dict):
if not isinstance(b, dict):
raise AssertionError('Values have different types')
if len(a) != len(b):
raise AssertionError('Dicts have different sizes')

for key in a:
if key not in b:
raise AssertionError(f'Dicts have different keys ({key} not found)')
_recursively_assert_array_equal(a[key], b[key])
else:
npt.assert_array_equal(a, b)


def test_save_and_load_dict_to_hdf5():
filename = os.path.join(get_tempdir(), 'test_hdf5.hdf5')
dict_to_save = {
'int_scalar': 1,
'int_vector': np.array([1, 2], dtype=int),
'int_matrix': np.array([[1, 2], [3, 4]], dtype=int),
'float32': np.array([[1., 2.], [3., 4.]], dtype='float32'),
'float32_w_nans': np.array([[1., 2.], [3., np.nan]], dtype='float32'),
'float64_w_nans': np.array([[1., 2.], [3., np.nan]], dtype='float64'),
'dict': {
'nested_float': np.array([1.0, 2.0])
},
'string': 'foobar',
'bool': True,
'dxy': (1.0, 2.0) # specific key that should be saved as a tuple
}
# test no validation error on save
utils.save_dict_to_hdf5(dict_to_save, filename)

# test that the same data gets loaded
loaded = utils.load_dict_from_hdf5(filename)
_recursively_assert_array_equal(dict_to_save, loaded)


if __name__ == '__main__':
test_save_and_load_dict_to_hdf5()
2 changes: 1 addition & 1 deletion caiman/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def recursively_save_dict_contents_to_group(h5file:h5py.File, path:str, dic:dict
except:
item = np.array(item).astype('|S32')
h5file[path + key] = item
if not np.array_equal(h5file[path + key][()], item, equal_nan=item.dtype == 'f'): # just using True gives "ufunc 'isnan' not supported for the input types"
if not np.array_equal(h5file[path + key][()], item, equal_nan=item.dtype.kind == 'f'): # just using True gives "ufunc 'isnan' not supported for the input types"
raise ValueError(f'Error while saving ndarray {key} of dtype {item.dtype}')
# save dictionaries
elif isinstance(item, dict):
Expand Down

0 comments on commit f32e8ad

Please sign in to comment.