diff --git a/src/elexsolver/OLSRegressionSolver.py b/src/elexsolver/OLSRegressionSolver.py index ca05535..ef8d0bb 100644 --- a/src/elexsolver/OLSRegressionSolver.py +++ b/src/elexsolver/OLSRegressionSolver.py @@ -82,6 +82,9 @@ def _compute_normal_equations( # lambda_I is the regularization matrix return np.linalg.inv(R.T @ R + lambda_I) @ R.T @ Q.T + def _normalize_weights(self, weights): + return weights / weights.sum() + def fit( self, x: np.ndarray, @@ -105,7 +108,7 @@ def fit( # normalize weights and turn into diagional matrix # square root because will be squared when R^T R happens later - L = np.diag(np.sqrt(weights.flatten() / weights.sum())) + L = np.diag(np.sqrt(self._normalize_weights(weights.flatten()))) # if normal equations are provided then use those, otherwise compute them # in the bootstrap setting we can now pass in the normal equations and can @@ -123,7 +126,7 @@ def fit( # compute coefficients: (X^T X)^{-1} X^T y self.coefficients = self.normal_eqs @ L @ y - def residuals(self, y: np.ndarray, y_hat: np.ndarray, loo: bool = True, center: bool = True) -> np.ndarray: + def residuals(self, y: np.ndarray, y_hat: np.ndarray, weights: np.ndarray, loo: bool = True, center: bool = True) -> np.ndarray: """ Computes residuals for the model """ @@ -136,8 +139,8 @@ def residuals(self, y: np.ndarray, y_hat: np.ndarray, loo: bool = True, center: if loo: residuals /= (1 - self.hat_vals).reshape(-1, 1) - # centering removes the column mean + # centering removes the column weighted mean if center: - residuals -= np.mean(residuals, axis=0) + residuals -= np.sum(residuals * self._normalize_weights(weights), axis=0) return residuals