diff --git a/hoi/utils/stats.py b/hoi/utils/stats.py index 34dda36f..e6bbf501 100644 --- a/hoi/utils/stats.py +++ b/hoi/utils/stats.py @@ -250,6 +250,8 @@ def get_nbest_mult( import pandas as pd hoi = np.asarray(hoi).squeeze() + if not hoi.ndim: + hoi = hoi.reshape(-1) # get order and multiplets if model: