Add unittest for tune
yhna940 committed Aug 30, 2023
1 parent 3418ddc commit 92ad439
Showing 6 changed files with 521 additions and 8 deletions.
4 changes: 2 additions & 2 deletions mmengine/tune/_report_hook.py
@@ -58,10 +58,10 @@ def _should_stop(self, runner):
             runner (Runner): The runner of the training process.
         """
         if self.tuning_iter is not None:
-            if runner.iter > self.tuning_iter:
+            if runner.iter + 1 >= self.tuning_iter:
                 return True
         elif self.tuning_epoch is not None:
-            if runner.epoch > self.tuning_epoch:
+            if runner.epoch + 1 >= self.tuning_epoch:
                 return True
         else:
             return False
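
The boundary change fixes an off-by-one: runner.iter and runner.epoch are presumably zero-based at the time the hook checks them, so the old strict comparison stopped one unit too late. A quick illustrative check of the two conditions (not part of the commit):

    tuning_iter = 5
    # old condition first holds at iter 6, two iterations past the budget
    assert min(i for i in range(10) if i > tuning_iter) == 6
    # new condition first holds at iter 4, i.e. once 5 iterations are done
    assert min(i for i in range(10) if i + 1 >= tuning_iter) == 4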
16 changes: 10 additions & 6 deletions mmengine/tune/searchers/searcher.py
@@ -54,14 +54,18 @@ def _validate_hparam_spec(self, hparam_spec: Dict[str, Dict]):
                 'hparam_spec must have a key "type" and ' \
                 f'its value must be "discrete" or "continuous", but got {v}'
             if v['type'] == 'discrete':
-                assert 'values' in v, \
-                    'if hparam_spec["type"] is "discrete", ' +\
-                    f'hparam_spec must have a key "values", but got {v}'
+                assert 'values' in v and isinstance(v['values'], list) and \
+                    v['values'], 'Expected a non-empty "values" list for ' + \
+                    f'discrete type, but got {v}'
             else:
                 assert 'lower' in v and 'upper' in v, \
-                    'if hparam_spec["type"] is "continuous", ' +\
-                    'hparam_spec must have keys "lower" and "upper", ' +\
-                    f'but got {v}'
+                    'Expected keys "lower" and "upper" for continuous ' + \
+                    f'type, but got {v}'
+                assert isinstance(v['lower'], (int, float)) and \
+                    isinstance(v['upper'], (int, float)), \
+                    f'Expected "lower" and "upper" to be numbers, but got {v}'
+                assert v['lower'] < v['upper'], \
+                    f'Expected "lower" to be less than "upper", but got {v}'

     @property
     def hparam_spec(self) -> Dict[str, Dict]:
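
Together, the tightened checks mean a spec must now carry a non-empty value list or well-typed, correctly ordered bounds. A minimal spec that satisfies both branches (the hyperparameter names are illustrative only):

    valid_spec = {
        'lr': {'type': 'continuous', 'lower': 1e-4, 'upper': 1e-1},
        'batch_size': {'type': 'discrete', 'values': [16, 32, 64]},
    }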
96 changes: 96 additions & 0 deletions tests/test_tune/test_report_hook.py
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock

from mmengine.testing import RunnerTestCase
from mmengine.tune._report_hook import ReportingHook


class TestReportingHook(RunnerTestCase):

    def test_append_score(self):
        hook = ReportingHook(monitor='acc', max_scoreboard_len=3)

        # Adding scores to the scoreboard
        hook._append_score(0.5)
        hook._append_score(0.6)
        hook._append_score(0.7)
        self.assertEqual(hook.scoreboard, [0.5, 0.6, 0.7])

        # When exceeding max length, it should pop the first item
        hook._append_score(0.8)
        self.assertEqual(hook.scoreboard, [0.6, 0.7, 0.8])
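        # A plausible _append_score consistent with the behaviour asserted
        # above (an assumption, not necessarily the actual implementation):
        #
        #     def _append_score(self, score):
        #         self.scoreboard.append(score)
        #         if len(self.scoreboard) > self.max_scoreboard_len:
        #             self.scoreboard.pop(0)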

    def test_should_stop(self):
        runner = MagicMock(iter=3, epoch=1)

        # Test with tuning_iter
        hook1 = ReportingHook(monitor='acc', tuning_iter=5)
        self.assertFalse(hook1._should_stop(runner))
        runner.iter = 4
        self.assertTrue(hook1._should_stop(runner))

        # Test with tuning_epoch
        hook2 = ReportingHook(monitor='acc', tuning_epoch=3)
        self.assertFalse(hook2._should_stop(runner))
        runner.epoch = 2
        self.assertTrue(hook2._should_stop(runner))

    def test_report_score(self):
        hook1 = ReportingHook(monitor='acc', report_op='latest')
        hook1.scoreboard = [0.5, 0.6, 0.7]
        self.assertEqual(hook1.report_score(), 0.7)

        hook2 = ReportingHook(monitor='acc', report_op='mean')
        hook2.scoreboard = [0.5, 0.6, 0.7]
        self.assertEqual(hook2.report_score(), 0.6)

        # Test with an empty scoreboard
        hook3 = ReportingHook(monitor='acc', report_op='mean')
        self.assertIsNone(hook3.report_score())
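        # Note: the exact comparison with 0.6 above holds if the mean is
        # computed with exact arithmetic, e.g.
        # statistics.mean([0.5, 0.6, 0.7]) == 0.6, whereas
        # sum([0.5, 0.6, 0.7]) / 3 yields 0.5999999999999999. How
        # report_op='mean' is implemented is an assumption here.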

    def test_clear(self):
        hook = ReportingHook(monitor='acc')
        hook.scoreboard = [0.5, 0.6, 0.7]
        hook.clear()
        self.assertEqual(hook.scoreboard, [])

    def test_after_train_iter(self):
        runner = MagicMock(iter=3, epoch=1)
        runner.log_processor.get_log_after_iter = MagicMock(
            return_value=({'acc': 0.9}, 'log_str'))

        # Check if the monitored score gets appended correctly
        hook = ReportingHook(monitor='acc')
        hook.after_train_iter(runner, 0)
        self.assertEqual(hook.scoreboard[-1], 0.9)

        # Check if no score is appended for a non-existent metric
        hook2 = ReportingHook(monitor='non_existent')
        hook2.after_train_iter(runner, 0)
        self.assertEqual(len(hook2.scoreboard), 0)

        # Check that training stops if tuning_iter is reached. Preset the
        # flag to False: on a bare MagicMock the attribute would otherwise
        # be truthy even if the hook never set it.
        runner.iter = 5
        runner.train_loop.stop_training = False
        hook3 = ReportingHook(monitor='acc', tuning_iter=5)
        hook3.after_train_iter(runner, 0)
        self.assertTrue(runner.train_loop.stop_training)

    def test_after_val_epoch(self):
        runner = MagicMock(iter=3, epoch=1)

        # Check if the monitored score gets appended correctly from metrics
        metrics = {'acc': 0.9}
        hook = ReportingHook(monitor='acc')
        hook.after_val_epoch(runner, metrics=metrics)
        self.assertEqual(hook.scoreboard[-1], 0.9)

        # Check that no score is appended if the metric is missing
        metrics = {'loss': 0.1}
        hook2 = ReportingHook(monitor='acc')
        hook2.after_val_epoch(runner, metrics=metrics)
        self.assertEqual(len(hook2.scoreboard), 0)

    def test_with_runner(self):
        runner = self.build_runner(self.epoch_based_cfg)
        acc_hook = ReportingHook(monitor='test/acc', tuning_epoch=1)
        runner.register_hook(acc_hook, priority='VERY_LOW')
        runner.train()
        self.assertEqual(runner.epoch, 1)
        score = acc_hook.report_score()
        self.assertAlmostEqual(score, 1)
101 changes: 101 additions & 0 deletions tests/test_tune/test_searchers/test_nevergrad.py
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List
from unittest import TestCase, skipIf

from mmengine.tune.searchers import NevergradSearcher

try:
    import nevergrad  # noqa: F401
    NEVERGRAD_AVAILABLE = True
except ImportError:
    NEVERGRAD_AVAILABLE = False


@skipIf(not NEVERGRAD_AVAILABLE, "nevergrad is not installed")
class TestNevergradSearcher(TestCase):

    def noisy_sphere_function(self, x: List[float]):
        """Sphere function with noise: f(x) = sum(x_i^2) + noise."""
        noise = random.gauss(0, 0.1)  # Gaussian noise with mean 0, std 0.1
        return sum(x_i**2 for x_i in x) + noise

    def one_max_function(self, x: List[int]):
        """OneMax function: f(x) = sum(x_i) for binary x_i."""
        return sum(x)

    @property
    def target_solver_types(self) -> List[str]:
        return ['OnePlusOne', 'CMA', 'BO', 'DE', 'PSO', 'NGO']
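    # Note: the strings above are nevergrad optimizer registry names; which
    # of them are available can vary with the installed nevergrad version
    # (e.g. 'BO' may need optional extras).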

    def test_hash_dict(self):
        searcher = NevergradSearcher(
            rule='less',
            hparam_spec={},
            num_trials=100,
            solver_type='OnePlusOne')

        # Check different dicts yield different hashes
        d1 = {"x": 1, "y": 2}
        d2 = {"x": 1, "y": 3}
        self.assertNotEqual(searcher._hash_dict(d1), searcher._hash_dict(d2))

        # Check same dict yields same hash
        self.assertEqual(searcher._hash_dict(d1), searcher._hash_dict(d1))

        # Check key order doesn't matter
        d3 = {"y": 2, "x": 1}
        self.assertEqual(searcher._hash_dict(d1), searcher._hash_dict(d3))
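        # _hash_dict presumably serves as a stable lookup key for matching
        # recorded scores back to suggested candidates; the three checks
        # above pin down the properties such a key needs (distinct values,
        # determinism, key-order insensitivity).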

    def test_noisy_sphere_function(self):
        hparam_continuous_space = {
            'x1': {
                'type': 'continuous',
                'lower': -5.0,
                'upper': 5.0
            },
            'x2': {
                'type': 'continuous',
                'lower': -5.0,
                'upper': 5.0
            }
        }
        for solver_type in self.target_solver_types:
            searcher = NevergradSearcher(
                rule='less',
                hparam_spec=hparam_continuous_space,
                num_trials=100,
                solver_type=solver_type)
            for _ in range(100):
                hparam = searcher.suggest()
                score = self.noisy_sphere_function(list(hparam.values()))
                searcher.record(hparam, score)
            # For the noisy sphere function, the optimum is near x1 = x2 = 0
            best_hparam = searcher.suggest()
            self.assertAlmostEqual(best_hparam['x1'], 0.0, places=1)
            self.assertAlmostEqual(best_hparam['x2'], 0.0, places=1)

    def test_one_max_function(self):
        # Define the discrete search space for OneMax
        hparam_discrete_space = {
            'x1': {
                'type': 'discrete',
                'values': [0, 1]
            },
            'x2': {
                'type': 'discrete',
                'values': [0, 1]
            },
            'x3': {
                'type': 'discrete',
                'values': [0, 1]
            },
            'x4': {
                'type': 'discrete',
                'values': [0, 1]
            }
        }
        for solver_type in self.target_solver_types:
            searcher = NevergradSearcher(
                rule='greater',
                hparam_spec=hparam_discrete_space,
                num_trials=100,
                solver_type=solver_type)
            for _ in range(100):
                hparam = searcher.suggest()
                score = self.one_max_function(list(hparam.values()))
                searcher.record(hparam, score)
            # For the OneMax function, the optimum is x1 = x2 = x3 = x4 = 1
            best_hparam = searcher.suggest()
            self.assertEqual(best_hparam['x1'], 1)
            self.assertEqual(best_hparam['x2'], 1)
            self.assertEqual(best_hparam['x3'], 1)
            self.assertEqual(best_hparam['x4'], 1)
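
For readers unfamiliar with nevergrad, the suggest/record pairs above mirror nevergrad's ask/tell protocol. A rough sketch of what NevergradSearcher presumably does under the hood for the continuous case (an illustrative assumption, not the actual implementation):

    import nevergrad as ng

    param = ng.p.Dict(
        x1=ng.p.Scalar(lower=-5.0, upper=5.0),
        x2=ng.p.Scalar(lower=-5.0, upper=5.0),
    )
    optimizer = ng.optimizers.registry['OnePlusOne'](
        parametrization=param, budget=100)
    for _ in range(100):
        candidate = optimizer.ask()                  # ~ searcher.suggest()
        loss = sum(v ** 2 for v in candidate.value.values())
        optimizer.tell(candidate, loss)              # ~ searcher.record()
    best = optimizer.provide_recommendation().value  # dict like {'x1': ...}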
86 changes: 86 additions & 0 deletions tests/test_tune/test_searchers/test_searcher.py
@@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from mmengine.tune.searchers import Searcher


class TestSearcher(TestCase):

    def test_rule(self):
        valid_hparam_spec_1 = {
            'lr': {
                'type': 'discrete',
                'values': [0.01, 0.02, 0.03]
            }
        }
        # Invalid rule
        with self.assertRaises(AssertionError):
            Searcher(rule='invalid_rule', hparam_spec=valid_hparam_spec_1)
        # Valid rules should construct without raising
        Searcher(rule='greater', hparam_spec=valid_hparam_spec_1)
        Searcher(rule='less', hparam_spec=valid_hparam_spec_1)

    def test_validate_hparam_spec(self):
        # Unknown hparam spec type
        invalid_hparam_spec_1 = {
            'lr': {
                'type': 'unknown_type',
                'values': [0.01, 0.02, 0.03]
            }
        }
        with self.assertRaises(AssertionError):
            Searcher(rule='greater', hparam_spec=invalid_hparam_spec_1)

        # Missing keys in continuous hparam_spec
        invalid_hparam_spec_2 = {
            'lr': {
                'type': 'continuous',
                'lower': 0.01
            }
        }
        with self.assertRaises(AssertionError):
            Searcher(rule='less', hparam_spec=invalid_hparam_spec_2)

        # Invalid discrete hparam_spec
        invalid_hparam_spec_3 = {
            'lr': {
                'type': 'discrete',
                'values': []  # Empty list
            }
        }
        with self.assertRaises(AssertionError):
            Searcher(rule='greater', hparam_spec=invalid_hparam_spec_3)

        # Invalid continuous hparam_spec
        invalid_hparam_spec_4 = {
            'lr': {
                'type': 'continuous',
                'lower': 0.1,
                'upper': 0.01  # lower is greater than upper
            }
        }
        with self.assertRaises(AssertionError):
            Searcher(rule='less', hparam_spec=invalid_hparam_spec_4)

        # Invalid data type in continuous hparam_spec
        invalid_hparam_spec_5 = {
            'lr': {
                'type': 'continuous',
                'lower': '0.01',  # String instead of number
                'upper': 0.1
            }
        }
        with self.assertRaises(AssertionError):
            Searcher(rule='less', hparam_spec=invalid_hparam_spec_5)

    def test_hparam_spec_property(self):
        hparam_spec = {
            'lr': {
                'type': 'discrete',
                'values': [0.01, 0.02, 0.03]
            }
        }
        searcher = Searcher(rule='greater', hparam_spec=hparam_spec)
        self.assertEqual(searcher.hparam_spec, hparam_spec)

    def test_rule_property(self):
        searcher = Searcher(rule='greater', hparam_spec={})
        self.assertEqual(searcher.rule, 'greater')
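
These tests exercise the base Searcher directly; a concrete searcher only adds suggest and record on top of the validated spec. A minimal, hypothetical subclass consistent with the interface used in the nevergrad tests above (illustrative only, not part of mmengine):

    import random

    from mmengine.tune.searchers import Searcher


    class RandomSearcher(Searcher):
        """Toy searcher that samples uniformly from hparam_spec."""

        def suggest(self):
            hparam = {}
            for name, spec in self.hparam_spec.items():
                if spec['type'] == 'discrete':
                    hparam[name] = random.choice(spec['values'])
                else:
                    hparam[name] = random.uniform(spec['lower'],
                                                  spec['upper'])
            return hparam

        def record(self, hparam, score):
            pass  # a real searcher would fold (hparam, score) into its model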