-
Notifications
You must be signed in to change notification settings - Fork 361
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
521 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.testing import RunnerTestCase | ||
from mmengine.tune._report_hook import ReportingHook | ||
from unittest.mock import MagicMock | ||
|
||
class TestReportingHook(RunnerTestCase): | ||
def test_append_score(self): | ||
hook = ReportingHook(monitor='acc', max_scoreboard_len=3) | ||
|
||
# Adding scores to the scoreboard | ||
hook._append_score(0.5) | ||
hook._append_score(0.6) | ||
hook._append_score(0.7) | ||
self.assertEqual(hook.scoreboard, [0.5, 0.6, 0.7]) | ||
|
||
# When exceeding max length, it should pop the first item | ||
hook._append_score(0.8) | ||
self.assertEqual(hook.scoreboard, [0.6, 0.7, 0.8]) | ||
|
||
def test_should_stop(self): | ||
runner = MagicMock(iter=3, epoch=1) | ||
|
||
# Test with tuning_iter | ||
hook1 = ReportingHook(monitor='acc', tuning_iter=5) | ||
self.assertFalse(hook1._should_stop(runner)) | ||
runner.iter = 4 | ||
self.assertTrue(hook1._should_stop(runner)) | ||
|
||
# Test with tuning_epoch | ||
hook2 = ReportingHook(monitor='acc', tuning_epoch=3) | ||
self.assertFalse(hook2._should_stop(runner)) | ||
runner.epoch = 2 | ||
self.assertTrue(hook2._should_stop(runner)) | ||
|
||
def test_report_score(self): | ||
hook1 = ReportingHook(monitor='acc', report_op='latest') | ||
hook1.scoreboard = [0.5, 0.6, 0.7] | ||
self.assertEqual(hook1.report_score(), 0.7) | ||
|
||
hook2 = ReportingHook(monitor='acc', report_op='mean') | ||
hook2.scoreboard = [0.5, 0.6, 0.7] | ||
self.assertEqual(hook2.report_score(), 0.6) | ||
|
||
# Test with an empty scoreboard | ||
hook3 = ReportingHook(monitor='acc', report_op='mean') | ||
self.assertIsNone(hook3.report_score()) | ||
|
||
def test_clear(self): | ||
hook = ReportingHook(monitor='acc') | ||
hook.scoreboard = [0.5, 0.6, 0.7] | ||
hook.clear() | ||
self.assertEqual(hook.scoreboard, []) | ||
|
||
def test_after_train_iter(self): | ||
runner = MagicMock(iter=3, epoch=1) | ||
runner.log_processor.get_log_after_iter = MagicMock(return_value=({'acc': 0.9}, 'log_str')) | ||
|
||
# Check if the monitored score gets appended correctly | ||
hook = ReportingHook(monitor='acc') | ||
hook.after_train_iter(runner, 0) | ||
self.assertEqual(hook.scoreboard[-1], 0.9) | ||
|
||
# Check if no score is appended for a non-existent metric | ||
hook2 = ReportingHook(monitor='non_existent') | ||
hook2.after_train_iter(runner, 0) | ||
self.assertEqual(len(hook2.scoreboard), 0) | ||
|
||
# Check that training stops if tuning_iter is reached | ||
runner.iter = 5 | ||
hook3 = ReportingHook(monitor='acc', tuning_iter=5) | ||
hook3.after_train_iter(runner, 0) | ||
self.assertTrue(runner.train_loop.stop_training) | ||
|
||
def test_after_val_epoch(self): | ||
runner = MagicMock(iter=3, epoch=1) | ||
|
||
# Check if the monitored score gets appended correctly from metrics | ||
metrics = {'acc': 0.9} | ||
hook = ReportingHook(monitor='acc') | ||
hook.after_val_epoch(runner, metrics=metrics) | ||
self.assertEqual(hook.scoreboard[-1], 0.9) | ||
|
||
# Check that no score is appended if the metric is missing from metrics | ||
metrics = {'loss': 0.1} | ||
hook2 = ReportingHook(monitor='acc') | ||
hook2.after_val_epoch(runner, metrics=metrics) | ||
self.assertEqual(len(hook2.scoreboard), 0) | ||
|
||
def test_with_runner(self): | ||
runner = self.build_runner(self.epoch_based_cfg) | ||
acc_hook = ReportingHook(monitor='test/acc', tuning_epoch=1) | ||
runner.register_hook(acc_hook, priority='VERY_LOW') | ||
runner.train() | ||
self.assertEqual(runner.epoch, 1) | ||
score = acc_hook.report_score() | ||
self.assertAlmostEqual(score, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from unittest import TestCase, skipIf | ||
import random | ||
from typing import List | ||
|
||
from mmengine.tune.searchers import NevergradSearcher | ||
|
||
try: | ||
import nevergrad | ||
NEVERGRAD_AVAILABLE = True | ||
except ImportError: | ||
NEVERGRAD_AVAILABLE = False | ||
|
||
@skipIf(not NEVERGRAD_AVAILABLE, "nevergrad is not installed") | ||
class TestNevergradSearcher(TestCase): | ||
def noisy_sphere_function(self, x: List[float]): | ||
"""Sphere function with noise: f(x) = sum(x_i^2) + noise""" | ||
noise = random.gauss(0, 0.1) # Gaussian noise with mean 0 and std 0.1 | ||
return sum([x_i ** 2 for x_i in x.values()]) + noise | ||
|
||
def one_max_function(self, x: List[int]): | ||
"""OneMax function: f(x) = sum(x_i) for binary x_i""" | ||
return sum(x) | ||
|
||
@property | ||
def target_solver_types(self): | ||
return [ | ||
'OnePlusOne', 'CMA', 'BO', 'DE', 'PSO', 'NGO' | ||
] | ||
|
||
def test_hash_dict(self): | ||
searcher = NevergradSearcher(rule='less', hparam_spec={}, num_trials=100, solver_type='OnePlusOne') | ||
|
||
# Check different dicts yield different hashes | ||
d1 = {"x": 1, "y": 2} | ||
d2 = {"x": 1, "y": 3} | ||
self.assertNotEqual(searcher._hash_dict(d1), searcher._hash_dict(d2)) | ||
|
||
# Check same dict yields same hash | ||
self.assertEqual(searcher._hash_dict(d1), searcher._hash_dict(d1)) | ||
|
||
# Check order doesn't matter | ||
d3 = {"y": 2, "x": 1} | ||
self.assertEqual(searcher._hash_dict(d1), searcher._hash_dict(d3)) | ||
|
||
def test_noisy_sphere_function(self): | ||
hparam_continuous_space = { | ||
'x1': { | ||
'type': 'continuous', | ||
'lower': -5.0, | ||
'upper': 5.0 | ||
}, | ||
'x2': { | ||
'type': 'continuous', | ||
'lower': -5.0, | ||
'upper': 5.0 | ||
} | ||
} | ||
for solver_type in self.target_solver_types: | ||
searcher = NevergradSearcher(rule='less', hparam_spec=hparam_continuous_space, num_trials=100, solver_type=solver_type) | ||
for _ in range(100): | ||
hparam = searcher.suggest() | ||
score = self.noisy_sphere_function([v for _,v in hparam.items()]) | ||
searcher.record(hparam, score) | ||
# For the noisy sphere function, the optimal should be close to x1=0 and x2=0 | ||
best_hparam = searcher.suggest() | ||
self.assertAlmostEqual(best_hparam['x1'], 0.0, places=1) | ||
self.assertAlmostEqual(best_hparam['x2'], 0.0, places=1) | ||
|
||
def test_one_max_function(self): | ||
# Define the discrete search space for OneMax | ||
hparam_discrete_space = { | ||
'x1': { | ||
'type': 'discrete', | ||
'values': [0, 1] | ||
}, | ||
'x2': { | ||
'type': 'discrete', | ||
'values': [0, 1] | ||
}, | ||
'x3': { | ||
'type': 'discrete', | ||
'values': [0, 1] | ||
}, | ||
'x4': { | ||
'type': 'discrete', | ||
'values': [0, 1] | ||
} | ||
} | ||
for solver_type in self.target_solver_types: | ||
searcher = NevergradSearcher(rule='greater', hparam_spec=hparam_discrete_space, num_trials=100, solver_type=solver_type) | ||
for _ in range(100): | ||
hparam = searcher.suggest() | ||
score = self.one_max_function([v for _,v in hparam.items()]) | ||
searcher.record(hparam, score) | ||
# For the OneMax function, the optimal solution is x1=x2=x3=x4=1 | ||
best_hparam = searcher.suggest() | ||
self.assertEqual(best_hparam['x1'], 1) | ||
self.assertEqual(best_hparam['x2'], 1) | ||
self.assertEqual(best_hparam['x3'], 1) | ||
self.assertEqual(best_hparam['x4'], 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from unittest import TestCase | ||
|
||
from mmengine.tune.searchers import Searcher | ||
|
||
class TestSearcher(TestCase): | ||
|
||
def test_rule(self): | ||
valid_hparam_spec_1 = { | ||
'lr': { | ||
'type': 'discrete', | ||
'values': [0.01, 0.02, 0.03] | ||
} | ||
} | ||
# Invalid cases | ||
with self.assertRaises(AssertionError): | ||
Searcher(rule='invalid_rule', hparam_spec=valid_hparam_spec_1) | ||
Searcher(rule='greater', hparam_spec=valid_hparam_spec_1) | ||
Searcher(rule='less', hparam_spec=valid_hparam_spec_1) | ||
|
||
def test_validate_hparam_spec(self): | ||
# Unknown hparam spec type | ||
invalid_hparam_spec_1 = { | ||
'lr': { | ||
'type': 'unknown_type', | ||
'values': [0.01, 0.02, 0.03] | ||
} | ||
} | ||
with self.assertRaises(AssertionError): | ||
Searcher(rule='greater', hparam_spec=invalid_hparam_spec_1) | ||
|
||
# Missing keys in continuous hparam_spec | ||
invalid_hparam_spec_2 = { | ||
'lr': { | ||
'type': 'continuous', | ||
'lower': 0.01 | ||
} | ||
} | ||
with self.assertRaises(AssertionError): | ||
Searcher(rule='less', hparam_spec=invalid_hparam_spec_2) | ||
|
||
# Invalid discrete hparam_spec | ||
invalid_hparam_spec_3 = { | ||
'lr': { | ||
'type': 'discrete', | ||
'values': [] # Empty list | ||
} | ||
} | ||
with self.assertRaises(AssertionError): | ||
Searcher(rule='greater', hparam_spec=invalid_hparam_spec_3) | ||
|
||
# Invalid continuous hparam_spec | ||
invalid_hparam_spec_4 = { | ||
'lr': { | ||
'type': 'continuous', | ||
'lower': 0.1, | ||
'upper': 0.01 # lower is greater than upper | ||
} | ||
} | ||
with self.assertRaises(AssertionError): | ||
Searcher(rule='less', hparam_spec=invalid_hparam_spec_4) | ||
|
||
# Invalid data type in continuous hparam_spec | ||
invalid_hparam_spec_5 = { | ||
'lr': { | ||
'type': 'continuous', | ||
'lower': '0.01', # String instead of number | ||
'upper': 0.1 | ||
} | ||
} | ||
with self.assertRaises(AssertionError): | ||
Searcher(rule='less', hparam_spec=invalid_hparam_spec_5) | ||
|
||
def test_hparam_spec_property(self): | ||
hparam_spec = { | ||
'lr': { | ||
'type': 'discrete', | ||
'values': [0.01, 0.02, 0.03] | ||
} | ||
} | ||
searcher = Searcher(rule='greater', hparam_spec=hparam_spec) | ||
self.assertEqual(searcher.hparam_spec, hparam_spec) | ||
|
||
def test_rule_property(self): | ||
searcher = Searcher(rule='greater', hparam_spec={}) | ||
self.assertEqual(searcher.rule, 'greater') |
Oops, something went wrong.