diff --git a/pylabianca/decoding.py b/pylabianca/decoding.py index d64f436..94fb6d2 100644 --- a/pylabianca/decoding.py +++ b/pylabianca/decoding.py @@ -641,6 +641,19 @@ def correlation(X1, X2): return rval_sel +def corr_rows(A, B): + # Rowwise mean of input arrays & subtract from input arrays themeselves + A_mA = A - A.mean(axis=1)[:, None] + B_mB = B - B.mean(axis=1)[:, None] + + # Sum of squares across rows + ssA = (A_mA ** 2).sum(axis=1) + ssB = (B_mB ** 2).sum(axis=1) + + # Finally get corr coeff + return np.dot(A_mA, B_mB.T) / np.sqrt(np.dot(ssA[:, None], ssB[None])) + + # TODO: profile and consider using a better correlation function # (numba for example) class maxCorrClassifier(BaseEstimator): @@ -681,7 +694,7 @@ def fit(self, X, y, scoring=None): avg = X[msk, :].mean(axis=0) self.class_averages_.append(avg) - self.class_averages_ = np.stack(self.class_averages_, axis=1) + self.class_averages_ = np.stack(self.class_averages_, axis=0) self.scoring = 'accuracy' if scoring is None else scoring return self @@ -719,16 +732,16 @@ def predict(self, X): message='invalid value encountered in true_divide', category=RuntimeWarning ) - r = correlation(self.class_averages_, X.T) + r = corr_rows(self.class_averages_, X) else: - distance = self.class_averages_[..., None] - X.T[:, None, :] - r = np.linalg.norm(distance, axis=0) * -1 + distance = self.class_averages_[:, None, :] - X[None, :, :] + r = np.linalg.norm(distance, axis=-1) * -1 bad_trials = np.isnan(r).all(axis=0) if bad_trials.any(): - distance = (self.class_averages_[..., None] - - X.T[:, None, bad_trials]) - r[:, bad_trials] = np.linalg.norm(distance, axis=0) * -1 + distance = (self.class_averages_[:, None, :] + - X[None, bad_trials, :]) + r[:, bad_trials] = np.linalg.norm(distance, axis=-1) * -1 # pick class with best correlation: r_best = r.argmax(axis=0)