diff --git a/src/pydvl/valuation/result.py b/src/pydvl/valuation/result.py index 1406d2c3c..776836910 100644 --- a/src/pydvl/valuation/result.py +++ b/src/pydvl/valuation/result.py @@ -263,7 +263,7 @@ def __init__( if indices is None: indices = np.arange(len(self._values), dtype=np.int_) - self._indices = np.array(indices, dtype=np.int_, copy=False) + self._indices = np.array(indices, dtype=indices.dtype, copy=False) self._positions = {idx: pos for pos, idx in enumerate(indices)} self._sort_positions: NDArray[np.int_] = np.arange(