-
Notifications
You must be signed in to change notification settings - Fork 8
/
minibelt.py
executable file
·865 lines (636 loc) · 26 KB
/
minibelt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
# -*- coding: utf-8 -*-
"""
One-file utility module filled with helper functions for day to day Python
programming
This is a subset of batbelt, with only the most used features, packed in a tiny
file so you can just drop it in your project and forget about it.
It's under zlib licence.
"""
from __future__ import unicode_literals

import codecs
import json
import os
import re
import sys
import unicodedata
from collections import deque
from datetime import date, datetime, time, timedelta
from itertools import chain, islice

# MutableSet lives in collections.abc since Python 3.3 (the old alias in
# `collections` was removed in 3.10); fall back to the Python 2 location.
try:
    from collections.abc import MutableSet
except ImportError:
    from collections import MutableSet
__version__ = '0.2.2'

# Names exported by ``from minibelt import *``. Every entry must be defined
# in this module, otherwise the star-import raises AttributeError.
__all__ = [
    'slugify', 'normalize', 'json_dumps', 'json_loads', 'CLASSIC_DATETIME_FORMAT',
    'import_from_path', 'attr', 'chunks', 'window', 'dmerge',
    'get', 'subdict', 'iget', 'skip_duplicates', 'sset', 'unpack',
    'add_to_pythonpath', 'write', 'flatten'
]

# Default datetime serialization format used by the JSON helpers, and the
# regex recognizing strings written in that format. Both contain exactly
# one space: JSONEncoder / JSONDecoder split them on it to obtain the
# date-only and time-only variants.
CLASSIC_DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S.%f'
CLASSIC_DATETIME_PATTERN = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}'
# Python 3 dropped the ``unicode`` and ``xrange`` builtins: alias them to
# their Python 3 equivalents so the rest of the module can use one name
# on both major versions (on Python 2 the builtins are used as is).
if sys.version_info[0] >= 3:
    unicode = str
    xrange = range
try:
    import unidecode

    def normalize(string):
        r"""
        Returns a new string without non ASCII characters, trying to replace
        them with their closest ASCII counterparts when possible.

        :Example:

        >>> normalize(u"H\xe9ll\xf8 W\xf3rld")
        'Hello World'

        This version uses unidecode and provides enhanced results.
        """
        return unidecode.unidecode(string)

except ImportError:

    def normalize(string):
        r"""
        Returns a new string without non ASCII characters, trying to replace
        them with their closest ASCII counterparts when possible.

        :Example:

        >>> normalize(u"H\xe9ll\xf8 W\xf6rld")
        'Hell World'

        This version uses unicodedata (NFKD decomposition, then drops what
        is still not ASCII) and provides limited yet useful results:
        characters with no ASCII decomposition are removed.
        """
        string = unicodedata.normalize('NFKD', string).encode('ascii', 'ignore')
        return string.decode('ascii')


def slugify(string, separator=r'-'):
    r"""
    Slugify a unicode string, using normalize() to turn it into ASCII first
    (unidecode when it is installed, unicodedata otherwise).

    Characters that are neither word characters, whitespace nor the
    separator are removed, the result is stripped and lowercased, then runs
    of whitespace and separators are collapsed into one separator.
    `separator` is used literally, even if it has a meaning in regexes.

    :Example:

    >>> slugify(u"H\xe9ll\xf6 W\xf6rld")
    'hello-world'
    >>> slugify("Bonjour, tout l'monde !", separator="_")
    'bonjour_tout_lmonde'
    >>> slugify("\tStuff with -- dashes and... spaces \n")
    'stuff-with-dashes-and-spaces'
    """
    # The separator is inserted in regex character classes: escape it so
    # that characters such as '^', ']' or '\' are taken literally.
    escaped = re.escape(separator)
    string = normalize(string)
    string = re.sub(r'[^\w\s' + escaped + r']', '', string, flags=re.U)
    string = string.strip().lower()
    # A function replacement inserts the separator verbatim: a plain string
    # replacement would interpret backslashes as escapes/group references.
    return re.sub(r'[' + escaped + r'\s]+', lambda match: separator, string,
                  flags=re.U)
class JSONEncoder(json.JSONEncoder):
    """
    JSON encoder that also serializes datetime, date, time and timedelta
    objects, as strings written with the formats below.

    Only naive datetimes are supported: the formats carry no timezone, so
    keep timezone information in a separate field.

    Each format can be overridden per instance through the constructor; the
    remaining arguments are forwarded to ``json.JSONEncoder``.
    """

    DATETIME_FORMAT = CLASSIC_DATETIME_FORMAT
    # The datetime format is "<date part> <time part>": reuse each half.
    DATE_FORMAT, TIME_FORMAT = DATETIME_FORMAT.split()
    TIMEDELTA_FORMAT = "timedelta(seconds='%s')"

    def __init__(self, datetime_format=None, date_format=None, time_format=None,
                 timedelta_format=None, *args, **kwargs):
        # Falsy arguments (None, '') fall back to the class-level defaults.
        self.datetime_format = (datetime_format if datetime_format
                                else self.DATETIME_FORMAT)
        self.date_format = date_format if date_format else self.DATE_FORMAT
        self.time_format = time_format if time_format else self.TIME_FORMAT
        self.timedelta_format = (timedelta_format if timedelta_format
                                 else self.TIMEDELTA_FORMAT)
        super(JSONEncoder, self).__init__(*args, **kwargs)

    def default(self, obj):
        """
        Return a string for date/time objects; defer anything else to
        ``json.JSONEncoder.default``.
        """
        # datetime is a subclass of date, so it has to be tested first.
        if isinstance(obj, datetime):
            pattern = self.datetime_format
        elif isinstance(obj, date):
            pattern = self.date_format
        elif isinstance(obj, time):
            pattern = self.time_format
        elif isinstance(obj, timedelta):
            return self.timedelta_format % obj.total_seconds()
        else:
            return json.JSONEncoder.default(self, obj)
        return obj.strftime(pattern)
class JSONDecoder(json.JSONDecoder):
    """
    Json decoder that decodes JSON encoded with JSONEncoder.

    Strings found as object values, as array items or as the top-level
    value are turned back into datetime, date, time or timedelta objects
    when they fit the corresponding format. Object keys and non-string
    values are left untouched, and so are strings that only look like a
    date but cannot be parsed.
    """

    DATETIME_PATTERN = CLASSIC_DATETIME_PATTERN
    DATE_PATTERN, TIME_PATTERN = DATETIME_PATTERN.split()
    # Anchored so that text merely mentioning a timedelta stays a string.
    # Accepts an optional sign and exponent, since JSONEncoder formats
    # total_seconds() with '%s' and produces e.g. '-1.0' or '1e-06'.
    TIMEDELTA_PATTERN = (r"\Atimedelta\(seconds='(?P<seconds>"
                         r"-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)'\)\Z")

    def __init__(self, datetime_pattern=None, date_pattern=None,
                 time_pattern=None, timedelta_pattern=None, datetime_format=None,
                 date_format=None, time_format=None, *args, **kwargs):
        self.datetime_format = datetime_format or JSONEncoder.DATETIME_FORMAT
        self.date_format = date_format or JSONEncoder.DATE_FORMAT
        self.time_format = time_format or JSONEncoder.TIME_FORMAT
        self.datetime_pattern = re.compile(datetime_pattern or self.DATETIME_PATTERN)
        self.date_pattern = re.compile(date_pattern or self.DATE_PATTERN)
        self.time_pattern = re.compile(time_pattern or self.TIME_PATTERN)
        self.timedelta_pattern = re.compile(timedelta_pattern or self.TIMEDELTA_PATTERN)
        super(JSONDecoder, self).__init__(object_pairs_hook=self.object_pairs_hook,
                                          *args, **kwargs)

    def decode(self, s, *args, **kwargs):
        """
        Decode the JSON document `s`.

        Objects are already converted by object_pairs_hook while parsing;
        this extra pass converts a top-level array or string.
        """
        value = super(JSONDecoder, self).decode(s, *args, **kwargs)
        return self.decode_on_match(value)

    def object_pairs_hook(self, obj):
        """Build a dict from the parsed pairs, converting every value."""
        return dict((k, self.decode_on_match(v)) for k, v in obj)

    def decode_on_match(self, obj):
        """
        Try to match the string, and if it fits any date format,
        parse it and return a Python object.

        Strings are checked against the datetime, date, time and timedelta
        patterns, in that order. A string matching a pattern but rejected by
        the corresponding parser (e.g. "2000-13-45", or a date embedded in a
        sentence) is returned unchanged instead of raising. Lists are
        converted item by item; any other value is returned unchanged.
        """
        if isinstance(obj, list):
            return [self.decode_on_match(item) for item in obj]
        if not isinstance(obj, (unicode, str)):
            # Numbers, booleans, None and already converted dicts: their
            # textual form may contain date-like text that must not be
            # parsed (this used to crash on lists such as ["2000-01-01"]).
            return obj
        for pattern, convert in self._converters():
            match = pattern.search(obj)
            if match is None:
                continue
            try:
                return convert(match)
            except (ValueError, OverflowError):
                # Looks like the format but is not a valid value: try the
                # next format, and eventually keep the string as is.
                continue
        return obj

    def _converters(self):
        """Return (compiled pattern, converter) pairs, in matching order."""
        return (
            (self.datetime_pattern,
             lambda m: datetime.strptime(m.string, self.datetime_format)),
            (self.date_pattern,
             lambda m: datetime.strptime(m.string, self.date_format).date()),
            (self.time_pattern,
             lambda m: datetime.strptime(m.string, self.time_format).time()),
            (self.timedelta_pattern,
             lambda m: timedelta(seconds=float(m.group('seconds')))),
        )
def json_dumps(data, datetime_format=None, date_format=None, time_format=None,
               timedelta_format=None, *args, **kwargs):
    r"""
    Serialize `data` to a JSON string like Python's json.dumps, with extra
    support for datetime, date, time and timedelta values (see JSONEncoder
    for the formats, each of which can be overridden here). Remaining
    arguments are forwarded to the encoder.

    Example:

    >>> import datetime
    >>> json_dumps({'test': datetime.datetime(2000, 1, 1, 1, 1, 1)})
    '{"test": "2000-01-01 01:01:01.000000"}'
    >>> json_dumps({'test': datetime.date(2000, 1, 1)})
    '{"test": "2000-01-01"}'
    >>> json_dumps({'test': datetime.time(1, 1, 1)})
    '{"test": "01:01:01.000000"}'
    >>> json_dumps({'test': datetime.timedelta(1, 1)})
    '{"test": "timedelta(seconds=\'86401.0\')"}'
    >>> json_dumps({'test': datetime.timedelta(1, 1), 'a': [1, 2]})
    '{"test": "timedelta(seconds=\'86401.0\')", "a": [1, 2]}'
    """
    encoder = JSONEncoder(datetime_format, date_format, time_format,
                          timedelta_format, *args, **kwargs)
    return encoder.encode(data)
def json_loads(string, datetime_pattern=None, date_pattern=None,
               time_pattern=None, timedelta_pattern=None, datetime_format=None,
               date_format=None, time_format=None, *args, **kwargs):
    r"""
    Parse a JSON string like Python's json.loads, turning back into Python
    objects the date, time and timedelta strings produced by json_dumps.
    Patterns (used to recognize values) and formats (used to parse them)
    can be overridden; remaining arguments are forwarded to the decoder.

    Example:

    >>> json_loads('{"test": "2000-01-01 01:01:01.000000"}')
    {'test': datetime.datetime(2000, 1, 1, 1, 1, 1)}
    >>> json_loads('{"test": "2000-01-01"}')
    {'test': datetime.date(2000, 1, 1)}
    >>> json_loads('{"test": "01:01:01.000000"}')
    {'test': datetime.time(1, 1, 1)}
    >>> json_loads('{"test": "timedelta(seconds=\'86401.0\')"}') == {'test': timedelta(1, 1)}
    True
    >>> json_loads('{"test": "timedelta(seconds=\'86401.0\')", "a": [1, 2]}') == {'test': timedelta(1, 1), 'a': [1, 2]}
    True
    """
    decoder = JSONDecoder(datetime_pattern, date_pattern, time_pattern,
                          timedelta_pattern, datetime_format, date_format,
                          time_format, *args, **kwargs)
    return decoder.decode(string)
def import_from_path(path):
    """
    Import an object (usually a class) dynamically, given its dotted path.

    Everything before the last dot is imported as a module, and the last
    component is looked up as an attribute of that module.

    :Example:

    >>> import_from_path('collections.OrderedDict').__name__
    'OrderedDict'

    :raises ImportError: if `path` has no module part or no attribute part,
                         if the module cannot be imported, or if it does
                         not have the requested attribute.
    """
    module_name, _, class_name = path.rpartition('.')
    if not module_name or not class_name:
        # e.g. 'os', '.os' or 'os.': used to fail with a ValueError from
        # tuple unpacking, which callers catching ImportError would miss.
        raise ImportError('Unable to import %s' % path)
    try:
        # fromlist makes __import__ return the leaf module of a dotted name
        # (and import it if it is a submodule). str() keeps a native string
        # there: Python 2 rejects unicode items in fromlist.
        module = __import__(module_name, fromlist=[str(class_name)])
        return getattr(module, class_name)
    except AttributeError:
        raise ImportError('Unable to import %s' % path)
def attr(obj, *attrs, **kwargs):
    """
    Walk a chain of attribute names starting from `obj` and return the last
    value reached, or the `default` keyword argument (None if not given)
    when the chain is broken.

        res = attr(data, 'test', 'o', 'bla', default="yeah")

    behaves like:

        try:
            res = getattr(getattr(getattr(data, 'test'), 'o'), 'bla')
        except AttributeError:
            res = "yeah"

    Calling it without any attribute name also returns the default. An
    IndexError raised while reading an attribute is treated like an
    AttributeError.
    """
    fallback = kwargs.get('default', None)
    if not attrs:
        return fallback
    current = obj
    try:
        for name in attrs:
            current = getattr(current, name)
    except (AttributeError, IndexError):
        return fallback
    return current
def chunks(seq, chunksize, process=tuple):
    """
    Yields items from an iterable in chunks of `chunksize` items; the last
    chunk may be shorter. Each chunk is passed to `process` (tuple by
    default) before being yielded.

    :Example:

    >>> list(chunks([1, 2, 3, 4, 5], 2))
    [(1, 2), (3, 4), (5,)]
    >>> list(chunks(range(4), 2, list))
    [[0, 1], [2, 3]]
    >>> list(chunks([], 3))
    []

    `process` receives a lazy iterator over the shared source: if it does
    not consume it (e.g. process=iter), exhaust each chunk before asking for
    the next one, otherwise the next chunk starts where reading stopped.

    :raises ValueError: if `chunksize` is lower than 1.
    """
    if chunksize < 1:
        raise ValueError('chunksize must be at least 1, got %r' % (chunksize,))
    it = iter(seq)
    while True:
        # Detect the end of the source explicitly: since PEP 479 (Python
        # 3.7), a StopIteration escaping a generator becomes a RuntimeError.
        try:
            first = next(it)
        except StopIteration:
            return
        yield process(chain([first], islice(it, chunksize - 1)))
def window(iterable, size=2, cast=tuple):
    """
    Yield a sliding window of `size` items over `iterable`, moving one item
    at a time.

    >>> list(window([1, 2, 3]))
    [(1, 2), (2, 3)]

    Each window is passed to `cast` (tuple by default) before being
    yielded; any callable accepting an iterable works.

    With a falsy `cast` (e.g. None) the internal deque itself is yielded,
    which is faster but means every step hands out the SAME deque object,
    updated in place. Keeping references to the windows therefore gives
    surprising results:

    >>> list(window([1, 2, 3], cast=None))
    [deque([2, 3], maxlen=2), deque([2, 3], maxlen=2)]

    If the iterable has fewer than `size` items, a single shorter window
    is yielded.
    """
    source = iter(iterable)
    current = deque(islice(source, size), size)
    if not cast:
        # Raw mode: the very same deque is handed out at each step.
        yield current
        for element in source:
            current.append(element)
            yield current
        return
    yield cast(current)
    for element in source:
        current.append(element)
        yield cast(current)
def dmerge(d1, d2, merge_func=None):
    """
    Create a new dictionary being the merge of the two passed as
    parameters; neither input is modified.

    If a key is in both dictionaries, the values are combined with
    ``merge_func(value_from_d1, value_from_d2)``. By default (no
    merge_func) the value from the second dictionary replaces the value
    from the first one.

    :Example:

    >>> dmerge({'a': 1, 'b': 2}, {'b': 3}) == {'a': 1, 'b': 3}
    True
    >>> dmerge({'a': 1, 'b': 2}, {'b': 3}, lambda x, y: x + y) == {'a': 1, 'b': 5}
    True
    """
    merged = dict(d1)
    if merge_func is None:
        merged.update(d2)
        return merged
    # .items() instead of the Python 2-only .iteritems(), which raised
    # AttributeError on Python 3 as soon as merge_func was given.
    for key, value in d2.items():
        if key in merged:
            merged[key] = merge_func(merged[key], value)
        else:
            merged[key] = value
    return merged
def get(data, *keys, **kwargs):
    """
    Dig into nested mappings and sequences by applying each key or index in
    turn, returning the `default` keyword argument (None if not given) when
    a step fails.

        res = get(data, 'test', 0, 'bla', default="yeah")

    behaves like:

        try:
            res = data['test'][0]['bla']
        except (KeyError, IndexError):
            res = "yeah"

    A TypeError (e.g. indexing a non-subscriptable value) also yields the
    default, and so does calling it without any key.
    """
    fallback = kwargs.get('default', None)
    if not keys:
        return fallback
    current = data
    try:
        for key in keys:
            current = current[key]
    except (KeyError, IndexError, TypeError):
        return fallback
    return current
def subdict(dct, include=(), exclude=()):
    """
    Return a new dictionary holding part of the items of `dct`.

    When `include` is non-empty, only the keys it contains are kept.
    Otherwise every key is kept except those listed in `exclude`. The two
    are not combined: if `include` is given, `exclude` is ignored.

    Example:

    >>> subdict({1:None, 2: False, 3: True}, [1, 2])
    {1: None, 2: False}
    >>> subdict({1:None, 2: False, 3: True}, exclude=[1, 2])
    {3: True}
    """
    selecting = bool(include)
    reference = include if selecting else exclude
    # Keep a key when its membership in `reference` matches the mode:
    # present for an include list, absent for an exclude list.
    return {key: val for key, val in dct.items() if (key in reference) == selecting}
def iget(data, value, default=None):
    """
    Index into any iterable, generators included, returning `default` when
    the index is out of range.

    :Example:

    >>> iget(xrange(10), 0)
    0
    >>> iget(xrange(10), 5)
    5
    >>> iget(xrange(10), 10000, default='wololo')
    'wololo'

    Negative indices count from the end:

    >>> iget(xrange(10), -3)
    7
    >>> iget(xrange(10), -1)
    9
    >>> iget(xrange(10), -10000, default='wololo')
    'wololo'

    Items are consumed to reach the requested one: a positive index eats
    the iterator up to that point, a negative one exhausts it entirely.
    A negative index on an infinite generator therefore never returns;
    bound it with itertools.islice when in doubt.
    """
    if value < 0:
        span = -value
        # Only the last `span` items are kept while the iterable is drained.
        tail = deque(data, span)
        return tail[0] if len(tail) == span else default
    return next(islice(data, value, None), default)
def unpack(indexable, *args, **kwargs):
    """
    Return a generator yielding ``indexable[key]`` for each given key or
    index, or the `default` keyword argument (None if not given) when the
    lookup raises KeyError, IndexError or TypeError.

    :Example:

    >>> dct = {'a': 2, 'b': 4, 'z': 42}
    >>> a, b, c = unpack(dct, 'a', 'b', 'c', default=1)
    >>> a
    2
    >>> b
    4
    >>> c
    1
    >>> list(unpack(range(5, 10), 2, 4))
    [7, 9]
    """
    fallback = kwargs.get('default', None)
    for key in args:
        # The lookup is kept out of the yield so that exceptions thrown into
        # the generator are never swallowed by this handler.
        try:
            found = indexable[key]
        except (KeyError, IndexError, TypeError):
            found = fallback
        yield found
def skip_duplicates(iterable, key=None):
    """
    Returns a generator that will yield all objects from iterable, skipping
    duplicates, in their original order.

    Duplicates are identified by a fingerprint computed with the `key`
    function; when `key` is None (the default) the fingerprint is the
    object itself. Fingerprints already seen are stored in a set(), so they
    are compared through their __hash__ and __eq__ methods: objects that do
    not define them are only duplicates of themselves.

    With the default, the function works as-is with iterables of
    primitives such as int, str or tuple.

    :Example:

    >>> list(skip_duplicates([1, 2, 3, 4, 4, 2, 1, 3 , 4]))
    [1, 2, 3, 4]

    The fingerprint MUST be hashable: for non hashable objects such as
    dict, set or list, pass a `key` function returning a hashable
    fingerprint.

    :Example:

    >>> list(skip_duplicates(([], [], (), [1, 2], (1, 2)), lambda x: tuple(x)))
    [[], [1, 2]]
    >>> list(skip_duplicates(([], [], (), [1, 2], (1, 2)), lambda x: (type(x), tuple(x))))
    [[], (), [1, 2], (1, 2)]

    For custom classes without __eq__/__hash__, nothing is removed by
    default. Provide a `key` function if you wish to filter those.

    :Example:

    >>> class Test(object):
    ...     def __init__(self, foo='bar'):
    ...         self.foo = foo
    ...     def __repr__(self):
    ...         return "Test('%s')" % self.foo
    >>> list(skip_duplicates([Test(), Test(), Test('other')]))
    [Test('bar'), Test('bar'), Test('other')]
    >>> list(skip_duplicates([Test(), Test(), Test('other')], lambda x: x.foo))
    [Test('bar'), Test('other')]

    :raises TypeError: with an explanatory message when a fingerprint is
                       not hashable. TypeErrors raised by `key` or by the
                       iterable itself propagate unchanged.
    """
    fingerprints = set()
    for x in iterable:
        fingerprint = x if key is None else key(x)
        # Only the set operations are guarded, so the fingerprint is always
        # bound in the handler and unrelated TypeErrors are not rewritten.
        try:
            if fingerprint in fingerprints:
                continue
            fingerprints.add(fingerprint)
        except TypeError:
            raise TypeError(
                "Calculating the key on one element resulted in a non hashable "
                "object of type '%s'. Change the 'key' parameter to a function "
                "that always, returns a hashable object. Hint : primitives "
                "like int, str or tuple, are hashable, dict, set and list are "
                "not. \nThe object that triggered the error was:\n%s" % (
                    type(fingerprint), x)
            )
        yield x
# Indices of the fields of a linked-list cell used by sset: [key, prev, next].
KEY, PREV, NEXT = range(3)


class sset(MutableSet):
    """
    Set that preserves insertion ordering.

    From http://code.activestate.com/recipes/576694/

    Elements are stored in a circular doubly linked list of
    ``[key, prev, next]`` cells whose ``end`` cell is a sentinel, plus a
    ``map`` from each key to its cell for O(1) lookup and removal.
    """

    def __init__(self, iterable=None):
        self.end = end = []
        end += [None, end, end]  # sentinel node for doubly linked list
        self.map = {}  # key --> [key, prev, next]
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        """Append `key` at the end, unless it is already present."""
        if key not in self.map:
            end = self.end
            curr = end[PREV]
            curr[NEXT] = end[PREV] = self.map[key] = [key, curr, end]

    def discard(self, key):
        """Remove `key` if present, unlinking its cell."""
        if key in self.map:
            key, prev_cell, next_cell = self.map.pop(key)
            prev_cell[NEXT] = next_cell
            next_cell[PREV] = prev_cell

    def __iter__(self):
        end = self.end
        curr = end[NEXT]
        while curr is not end:
            yield curr[KEY]
            curr = curr[NEXT]

    def __reversed__(self):
        end = self.end
        curr = end[PREV]
        while curr is not end:
            yield curr[KEY]
            curr = curr[PREV]

    def pop(self, last=True):
        """Remove and return the last element (the first if not `last`)."""
        if not self:
            raise KeyError('set is empty')
        key = next(reversed(self)) if last else next(iter(self))
        self.discard(key)
        return key

    def __repr__(self):
        if not self:
            return '%s()' % (self.__class__.__name__,)
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """
        Order-sensitive comparison with another sset; order-insensitive
        comparison with any other iterable of hashable items.
        """
        if isinstance(other, sset):
            return len(self) == len(other) and list(self) == list(other)
        try:
            return set(self) == set(other)
        except TypeError:
            # Not iterable (e.g. None, 5) or holding unhashable items: it
            # cannot equal a set, let Python use its default comparison
            # instead of raising.
            return NotImplemented

    def __reduce__(self):
        """
        Rebuild from the ordered elements when pickling or copying.

        The default protocol would copy the internal linked list: copy.copy
        shared it between both objects (so __del__ on one emptied the
        other) and pickle recursed once per element, overflowing the stack
        on large sets. Extra instance attributes of subclasses are not kept.
        """
        return (self.__class__, (list(self),))

    def __del__(self):
        # Remove circular references. Guarded for instances whose __init__
        # never ran (no `map` attribute yet).
        if getattr(self, 'map', None):
            self.clear()
def add_to_pythonpath(path, starting_point='.', insertion_index=None):
    """
    Add the directory to the sys.path, unless it is already there.

    You can pass an absolute or a relative path. '~' and environment
    variables ($VAR) are expanded in both `path` and `starting_point`.

    A relative path is resolved against `starting_point`, which is set to
    '.' (the current working directory) by default. You may want to set it
    to something like __file__: when `starting_point` is a file, its parent
    directory is used instead, which is probably what you expect in the
    first place.

    The final path is normalized with os.path.realpath, then appended to
    sys.path, or inserted at `insertion_index` when it is given.

    :example:

    >>> add_to_pythonpath('../..', __file__)
    """
    def _expand(some_path):
        """Expand '~' and environment variables in `some_path`."""
        return os.path.expandvars(os.path.expanduser(some_path))

    # Expand before testing absoluteness: '~/lib' or '$HOME/lib' are not
    # absolute as written, and joining them to starting_point first made
    # the expansion impossible ('./~/lib').
    path = _expand(path)
    if not os.path.isabs(path):
        starting_point = _expand(starting_point)
        if os.path.isfile(starting_point):
            starting_point = os.path.dirname(starting_point)
        path = os.path.join(starting_point, path)
    path = os.path.realpath(path)
    if path not in sys.path:
        if insertion_index is None:
            sys.path.append(path)
        else:
            sys.path.insert(insertion_index, path)
def write(path, *args, **kwargs):
    r"""
    Try to write to the file at `path` the values passed as `args`, one per
    line (each followed by os.linesep).

    Each value is converted to text automatically: byte strings are decoded
    with `encoding`, text is written as is, and anything else is written as
    its repr(). This is a utility function: it's slow and doesn't consider
    edge cases, but does just what you want most of the time in one line.

    :Example:

        s = '/tmp/test'
        write(s, 'test', 1, ['fdjskl'])
        print(open(s).read())
        test
        1
        ['fdjskl']

    You can optionally pass, as keyword arguments:

    - mode: 'w' (default) or 'a'. The file is opened with codecs.open,
      which forces binary mode, so line endings are written exactly as
      os.linesep.
    - encoding: 'utf8' by default; used both to decode byte strings and
      to encode the file.
    - errors: error handler for that decoding/encoding, 'replace' by
      default (faulty characters become '?' when encoding, U+FFFD when
      decoding).

    You can pass text or byte strings as *args, but byte strings must use
    the same encoding as the one you write the file with.
    """
    mode = kwargs.get('mode', 'w')
    encoding = kwargs.get('encoding', 'utf8')
    # Read from the 'errors' key: this used to read 'encoding', which
    # ignored `errors` and turned a custom encoding name into an invalid
    # error handler name.
    errors = kwargs.get('errors', 'replace')
    with codecs.open(path, mode=mode, encoding=encoding, errors=errors) as f:
        for line in args:
            if isinstance(line, bytes):
                line = line.decode(encoding, errors)
            if not isinstance(line, unicode):
                line = repr(line)
            f.write(line + os.linesep)
class Flattener(object):
    """
    Create a flattener that you can call on deeply nested data structures
    to iterate over their items as if they were a flat iterable.

    The flattener returns a generator that lazily yields the items. It
    recurses once per nesting level, so the maximum depth it can handle is
    bounded by the interpreter recursion limit (see sys.setrecursionlimit).

    A default flattener named 'flatten' is available at module level.

    :Example:

        a = []
        for i in range(10):
            a = [a, i]
        print(a)
        [[[[[[[[[[[], 0], 1], 2], 3], 4], 5], 6], 7], 8], 9]
        print(list(flatten(a)))
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    By default, it flattens all the types listed in
    Flattener.DEFAULT_FLATTEN_TYPES, but you can pass your own tuple of
    types as `flatten_types` when creating a Flattener instance.

    For ambiguous types like dict, you can pass `iterable_getters`, a
    mapping type -> callback letting you define how to extract items from
    each type.

    :Example:

        a = []
        for i in range(2):
            a = [a, i] + [{'a': 1., 'b': {'c': 3.}}]
        print(a)
        [[[], 0, {'a': 1.0, 'b': {'c': 3.0}}], 1, {'a': 1.0, 'b': {'c': 3.0}}]
        new_ft = Flattener.DEFAULT_FLATTEN_TYPES + (dict,)
        dico_flatten = Flattener(flatten_types=new_ft,
                                 iterable_getters={dict: lambda x: x.items()})
        print(list(dico_flatten(a)))
        [0, 'a', 1.0, 'b', 'c', 3.0, 1, 'a', 1.0, 'b', 'c', 3.0]
    """

    DEFAULT_FLATTEN_TYPES = (
        list,
        tuple,
        set,
        (x for x in ()).__class__,  # the generator type
        xrange,
        deque,
        MutableSet,
        # Sequence # warning, a string is a subclass of Sequence
    )

    def __init__(self, flatten_types=None, iterable_getters=None):
        """
        :param flatten_types: tuple of types to flatten; falsy values use
                              DEFAULT_FLATTEN_TYPES.
        :param iterable_getters: mapping type -> callable applied to
                                 instances of exactly that type before
                                 iterating over them (see
                                 transform_iterable). Defaults to an empty
                                 mapping.
        """
        self.flatten_types = flatten_types or self.DEFAULT_FLATTEN_TYPES
        # A fresh dict per instance: the former `{}` default was shared by
        # every flattener, so getters added to one leaked into all others.
        if iterable_getters is None:
            iterable_getters = {}
        self.iterable_getters = iterable_getters

    def should_flatten(self, obj):
        """
        Returns whether the object should be flattened, i.e. whether it is
        an instance of one of the types in self.flatten_types
        (DEFAULT_FLATTEN_TYPES by default).
        """
        return isinstance(obj, self.flatten_types)

    def transform_iterable(self, obj):
        """
        Apply a pre-processing to an object before iterating over it. Can
        be useful for types such as dict on which you may want to call
        values() or items() before iteration.

        By default, it checks if the object is a DIRECT instance (not a
        subclass) of any key in iterable_getters, passed in __init__, and
        applies the matching transformation.

        iterable_getters should be a mapping with types as keys and
        transformation functions as values, such as:

            {dict: lambda x: x.items()}

        iterable_getters defaults to an empty mapping, making
        transform_iterable a no-op.
        """
        if obj.__class__ in self.iterable_getters:
            return self.iterable_getters[obj.__class__](obj)
        return obj

    def __call__(self, iterable):
        """
        Returns a generator yielding items from a deeply nested iterable
        as if it were a flat one.
        """
        for e in iterable:
            if self.should_flatten(e):
                for f in self(self.transform_iterable(e)):
                    yield f
            else:
                yield e


flatten = Flattener()