Skip to content

Commit

Permalink
Merge pull request #344 from mesenrj/fix/pred-dist-memory-usage
Browse files Browse the repository at this point in the history
Fix issue causing large memory consumption in pred_dist()
  • Loading branch information
ryan-wolbeck authored Jan 31, 2024
2 parents c482aab + b45f7b6 commit ffbd359
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions ngboost/ngboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ def partial_fit(
# if early stopping is specified, split X,Y and sample weights (if given) into training and validation sets
# This will overwrite any X_val and Y_val values passed by the user directly.
if self.early_stopping_rounds is not None:

early_stopping_rounds = self.early_stopping_rounds

if sample_weight is None:
Expand Down Expand Up @@ -489,13 +488,9 @@ def pred_dist(self, X, max_iter=None):

X = check_array(X, accept_sparse=True)

if (
max_iter is not None
): # get prediction at a particular iteration if asked for
dist = self.staged_pred_dist(X, max_iter=max_iter)[-1]
else:
params = np.asarray(self.pred_param(X, max_iter))
dist = self.Dist(params.T)
params = np.asarray(self.pred_param(X, max_iter))
dist = self.Dist(params.T)

return dist

def staged_pred_dist(self, X, max_iter=None):
Expand Down

0 comments on commit ffbd359

Please sign in to comment.