diff --git a/python/mozbuild/mozbuild/test/test_containers.py b/python/mozbuild/mozbuild/test/test_containers.py index f1c77404400b8..9eb7df1ded7ef 100644 --- a/python/mozbuild/mozbuild/test/test_containers.py +++ b/python/mozbuild/mozbuild/test/test_containers.py @@ -7,11 +7,17 @@ from mozunit import main from mozbuild.util import ( + KeyedDefaultDict, List, + OrderedDefaultDict, ReadOnlyDefaultDict, ReadOnlyDict, + ReadOnlyKeyedDefaultDict, ) +from collections import OrderedDict + + class TestReadOnlyDict(unittest.TestCase): def test_basic(self): original = {'foo': 1, 'bar': 2} @@ -108,5 +114,83 @@ def test_none(self): with self.assertRaises(ValueError): test = test + False +class TestOrderedDefaultDict(unittest.TestCase): + def test_simple(self): + original = OrderedDict(foo=1, bar=2) + + test = OrderedDefaultDict(bool, original) + + self.assertEqual(original, test) + + self.assertEqual(test['foo'], 1) + + self.assertEqual(test.keys(), ['foo', 'bar' ]) + + def test_defaults(self): + test = OrderedDefaultDict(bool, {'foo': 1 }) + + self.assertEqual(test['foo'], 1) + + self.assertEqual(test['qux'], False) + + self.assertEqual(test.keys(), ['foo', 'qux' ]) + + +class TestKeyedDefaultDict(unittest.TestCase): + def test_simple(self): + original = {'foo': 1, 'bar': 2 } + + test = KeyedDefaultDict(lambda x: x, original) + + self.assertEqual(original, test) + + self.assertEqual(test['foo'], 1) + + def test_defaults(self): + test = KeyedDefaultDict(lambda x: x, {'foo': 1 }) + + self.assertEqual(test['foo'], 1) + + self.assertEqual(test['qux'], 'qux') + + self.assertEqual(test['bar'], 'bar') + + test['foo'] = 2 + test['qux'] = None + test['baz'] = 'foo' + + self.assertEqual(test['foo'], 2) + + self.assertEqual(test['qux'], None) + + self.assertEqual(test['baz'], 'foo') + + +class TestReadOnlyKeyedDefaultDict(unittest.TestCase): + def test_defaults(self): + test = ReadOnlyKeyedDefaultDict(lambda x: x, {'foo': 1 }) + + self.assertEqual(test['foo'], 1) + + self.assertEqual(test['qux'], 'qux') + + self.assertEqual(test['bar'], 'bar') + + copy = dict(test) + + with self.assertRaises(Exception): + test['foo'] = 2 + + with self.assertRaises(Exception): + test['qux'] = None + + with self.assertRaises(Exception): + test['baz'] = 'foo' + + self.assertEqual(test, copy) + + self.assertEqual(len(test), 3) + + if __name__ == '__main__': main() diff --git a/python/mozbuild/mozbuild/util.py b/python/mozbuild/mozbuild/util.py index eeb08075c7b00..e2b0bd478ac87 100644 --- a/python/mozbuild/mozbuild/util.py +++ b/python/mozbuild/mozbuild/util.py @@ -76,13 +76,10 @@ def __init__(self, default_factory, *args, **kwargs): ReadOnlyDict.__init__(self, *args, **kwargs) self._default_factory = default_factory - def __getitem__(self, key): - try: - return ReadOnlyDict.__getitem__(self, key) - except KeyError: - value = self._default_factory() - dict.__setitem__(self, key, value) - return value + def __missing__(self, key): + value = self._default_factory() + dict.__setitem__(self, key, value) + return value def ensureParentDir(path): @@ -715,12 +712,26 @@ def __init__(self, default_factory, *args, **kwargs): OrderedDict.__init__(self, *args, **kwargs) self._default_factory = default_factory - def __getitem__(self, key): - try: - return OrderedDict.__getitem__(self, key) - except KeyError: - value = self[key] = self._default_factory() - return value + def __missing__(self, key): + value = self[key] = self._default_factory() + return value + + +class KeyedDefaultDict(dict): + '''Like a defaultdict, but the default_factory function takes the key as + argument''' + def __init__(self, default_factory, *args, **kwargs): + dict.__init__(self, *args, **kwargs) + self._default_factory = default_factory + + def __missing__(self, key): + value = self._default_factory(key) + dict.__setitem__(self, key, value) + return value + + +class ReadOnlyKeyedDefaultDict(KeyedDefaultDict, ReadOnlyDict): + '''Like KeyedDefaultDict, but read-only.''' class memoize(dict):