diff --git a/.gitignore b/.gitignore index ae8d540..51e101d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ historical/ -# +# # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -92,3 +92,5 @@ ENV/ *.DS_Store *.orig + +.pytest_cache diff --git a/grabbit/core.py b/grabbit/core.py index 77616e7..0b77c3d 100644 --- a/grabbit/core.py +++ b/grabbit/core.py @@ -273,10 +273,7 @@ def match_file(self, f, update_file=False): m = self.regex.search(f.path) val = m.group(1) if m is not None else None - if val is not None and self.dtype is not None: - val = self.dtype(val) - - return val + return self._astype(val) def add_file(self, filename, value): """ Adds the specified filename to tracking. """ @@ -296,6 +293,11 @@ def count(self, files=False): """ return len(self.files) if files else len(self.unique()) + def _astype(self, val): + if val is not None and self.dtype is not None: + val = self.dtype(val) + return val + class Layout(object): @@ -860,7 +862,7 @@ def get_nearest(self, path, return_type='file', strict=True, all_=False, for ent in self.entities.values(): m = ent.regex.search(path) if m: - entities[ent.name] = m.group(1) + entities[ent.name] = ent._astype(m.group(1)) # Remove any entities we want to ignore when strict matching is on if strict and ignore_strict_entities is not None: diff --git a/grabbit/tests/test_core.py b/grabbit/tests/test_core.py index 61dcb0b..29eb02a 100644 --- a/grabbit/tests/test_core.py +++ b/grabbit/tests/test_core.py @@ -243,6 +243,7 @@ def test_dynamic_getters(self, data_dir, config): layout = Layout([(data_dir, config)], dynamic_getters=True) assert hasattr(layout, 'get_subjects') assert '01' in getattr(layout, 'get_subjects')() + assert 1 in getattr(layout, 'get_runs')() def test_querying(self, bids_layout): @@ -312,6 +313,12 @@ def test_get_nearest(self, bids_layout): assert len(nearest) == 3 assert nearest[0].subject == '01' + # Check for file with matching run (fails if types don't match) + nearest = bids_layout.get_nearest( + result, type='phasediff', extensions='.nii.gz') + assert nearest is not None + assert os.path.basename(nearest) == 'sub-01_ses-1_run-1_phasediff.nii.gz' + def test_index_regex(self, bids_layout, layout_include): targ = join('derivatives', 'excluded.json') assert targ not in bids_layout.files diff --git a/grabbit/utils.py b/grabbit/utils.py index cb4bf64..7e4d143 100644 --- a/grabbit/utils.py +++ b/grabbit/utils.py @@ -13,6 +13,8 @@ def natural_sort(l, field=None): def alphanum_key(key): if field is not None: key = getattr(key, field) + if not isinstance(key, str): + key = str(key) return [convert(c) for c in re.split('([0-9]+)', key)] return sorted(l, key=alphanum_key)