diff --git a/mmap_ninja/src/mmap_ninja/parallel.py b/mmap_ninja/src/mmap_ninja/parallel.py index 0ab8484..1094982 100644 --- a/mmap_ninja/src/mmap_ninja/parallel.py +++ b/mmap_ninja/src/mmap_ninja/parallel.py @@ -1,3 +1,4 @@ +from enum import Enum from functools import partial from tqdm.auto import tqdm @@ -10,7 +11,11 @@ HAS_JOBLIB = True -EXHAUSTED = '__EXHAUSTED__' +class _Exhausted(Enum): + exhausted = 'EXHAUSTED' + + +EXHAUSTED = _Exhausted.exhausted class ParallelBatchCollector: @@ -84,7 +89,7 @@ def _collect_no_parallel_batch(self): results = [_get_from_indexable(self.indexable, j) for j in self._rng()] if self.exhausted(results): - results = [r for r in results if not isinstance(r, str) or r != EXHAUSTED] + results = [r for r in results if r is not EXHAUSTED] return results @@ -94,7 +99,7 @@ def _collect_parallel_batch(self): results = self._parallel(func(j) for j in self._rng()) if self.exhausted(results): - results = [r for r in results if not isinstance(r, str) or r != EXHAUSTED] + results = [r for r in results if r is not EXHAUSTED] self._parallel.__exit__(None, None, None) return results @@ -102,7 +107,7 @@ def _collect_parallel_batch(self): def exhausted(self, results=()): self._exhausted = ( self._exhausted or - any(isinstance(r, str) and r == EXHAUSTED for r in results) or + any(r is EXHAUSTED for r in results) or self.completed_batches() )