Skip to content

Commit

Permalink
added weighted centering
Browse files Browse the repository at this point in the history
  • Loading branch information
lennybronner committed Oct 1, 2024
1 parent e551069 commit 0232bc3
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/elexsolver/OLSRegressionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def _compute_normal_equations(
# lambda_I is the regularization matrix
return np.linalg.inv(R.T @ R + lambda_I) @ R.T @ Q.T

def _normalize_weights(self, weights):
return weights / weights.sum()

def fit(
self,
x: np.ndarray,
Expand All @@ -105,7 +108,7 @@ def fit(

# normalize weights and turn into diagional matrix
# square root because will be squared when R^T R happens later
L = np.diag(np.sqrt(weights.flatten() / weights.sum()))
L = np.diag(np.sqrt(self._normalize_weights(weights.flatten())))

# if normal equations are provided then use those, otherwise compute them
# in the bootstrap setting we can now pass in the normal equations and can
Expand All @@ -123,7 +126,7 @@ def fit(
# compute coefficients: (X^T X)^{-1} X^T y
self.coefficients = self.normal_eqs @ L @ y

def residuals(self, y: np.ndarray, y_hat: np.ndarray, loo: bool = True, center: bool = True) -> np.ndarray:
def residuals(self, y: np.ndarray, y_hat: np.ndarray, weights: np.ndarray, loo: bool = True, center: bool = True) -> np.ndarray:
"""
Computes residuals for the model
"""
Expand All @@ -136,8 +139,8 @@ def residuals(self, y: np.ndarray, y_hat: np.ndarray, loo: bool = True, center:
if loo:
residuals /= (1 - self.hat_vals).reshape(-1, 1)

# centering removes the column mean
# centering removes the column weighted mean
if center:
residuals -= np.mean(residuals, axis=0)
residuals -= np.sum(residuals * self._normalize_weights(weights), axis=0)

return residuals

0 comments on commit 0232bc3

Please sign in to comment.