Skip to content

Commit

Permalink
bugfix for NParEGO: default scalarization_weights is None (random wei…
Browse files Browse the repository at this point in the history
…ghts); enable user input scalarization_weights (previous missing in _parse_aq_kwargs
  • Loading branch information
xuyuting committed Aug 19, 2024
1 parent d2df76d commit fcf8220
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion obsidian/acquisition/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@
'NIPV': {},
'EHVI': {'ref_point': {'val': None, 'optional': True}},
'NEHVI': {'ref_point': {'val': None, 'optional': True}},
'NParEGO': {'scalarization_weights': {'val': [1], 'optional': True}},
'NParEGO': {'scalarization_weights': {'val': None, 'optional': True}},
}
9 changes: 8 additions & 1 deletion obsidian/optimizer/bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def _validate_hypers(self,
if hps.get(key) is None:
if not defaults['optional']:
raise ValueError(f'Must specify hyperpameter value {key} for {aq_str}')
if key in ['scalarization_weights', 'weights']:
if key in ['weights']: #['scalarization_weights', 'weights']:
aq_hps[key] = defaults['val'] * o_dim
else:
aq_hps[key] = defaults['val']
Expand Down Expand Up @@ -551,6 +551,13 @@ def _parse_aq_kwargs(self,
qmc_samples = draw_sobol_samples(bounds=X_bounds, n=128, q=m_batch)
aq_kwargs['mc_points'] = qmc_samples.squeeze(-2)

if aq == 'NParEGO':
w = hps['scalarization_weights']
if isinstance(w,list):
w = torch.tensor(w)
w = w/torch.sum(torch.abs(w))
aq_kwargs['scalarization_weights'] = w

return aq_kwargs

def suggest(self,
Expand Down

0 comments on commit fcf8220

Please sign in to comment.