-
Notifications
You must be signed in to change notification settings - Fork 1
/
util.py
539 lines (473 loc) · 15.9 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
# util.py -- Generic utility functions and classes for FARGish
from collections.abc import Iterable
import collections
import random
import sys
from inspect import isclass
from dataclasses import dataclass, Field, fields, is_dataclass
from typing import Union, List, Tuple, Dict, Set, FrozenSet, Iterable, Any, \
NewType, Type, ClassVar, Sequence, Callable, Hashable, Collection, \
Sequence
from contextlib import AbstractContextManager
from types import SimpleNamespace
from itertools import chain, tee, filterfalse
import functools
# A single shared, immutable empty set.
# NOTE(review): presumably meant as a safe default value -- confirm against callers.
empty_set: FrozenSet = frozenset()
# Named one-character strings for newline and backslash.
# NOTE(review): presumably for use inside f-string expressions, which could
# not contain a backslash before Python 3.12 -- confirm against callers.
newline = '\n'
backslash = '\\'
def is_iter(o):
    '''Is o an iterable that should be treated as a collection of items?
    Strings and anything that is_namedtuple() accepts are iterable but count
    as single values here, so they yield False.'''
    if not isinstance(o, Iterable):
        return False
    if isinstance(o, str):
        return False
    return not is_namedtuple(o)
def is_namedtuple(o):
    '''Guesses whether o is a namedtuple.'''
    #HACK Duck typing: any object that has a '_fields' attribute is taken to
    # be a namedtuple. No check is made that o is actually a tuple.
    try:
        o._fields
    except AttributeError:
        return False
    return True
def as_iter(o) -> Iterable:
    '''Returns o in a form that the caller can iterate over.

    None becomes an empty list. An iterable accepted by is_iter() (so not a
    string) is returned unchanged. Anything else, including a string, is
    wrapped in a one-element list.'''
    if o is None:
        return []
    if isinstance(o, str) or not is_iter(o):
        return [o]
    return o
def as_list(o) -> List:
    '''Returns o itself if it is already a list; otherwise returns a new
    list built from as_iter(o).'''
    return o if isinstance(o, list) else list(as_iter(o))
def as_set(o):
    '''Returns o itself if it is already a set; otherwise returns a new set
    built from as_iter(o).'''
    return o if isinstance(o, set) else set(as_iter(o))
def as_hashable(o) -> Hashable:
    '''Tries to make a Hashable object out of o if it's not one already.

    An object that already registers as Hashable is returned unchanged. A
    list is converted to a tuple, and its elements are converted the same
    way, so nested lists become nested tuples.

    Raises ValueError for any other unhashable object (e.g. a dict or set),
    whether it is o itself or an element of a list inside o.'''
    if isinstance(o, collections.abc.Hashable):
        # Warning: if o is a dataclass with frozen=True, it will be seen as
        # Hashable but it won't be if it contains any unhashable members,
        # like a list or dict.
        return o
    if isinstance(o, list):
        # Recurse: tuple(o) alone would leave inner lists in place, and the
        # resulting tuple would still be unhashable.
        return tuple(as_hashable(elem) for elem in o)
    raise ValueError(f'{o} cannot be made hashable. Maybe it contains a list.')
# TODO UT
def omit(d: Dict, keys: Iterable) -> Dict:
    '''Returns copy of 'd' with all members of 'keys' omitted.

    'keys' is converted with as_set(), so it may be a set, any other
    iterable, or a single non-iterable key.'''
    # Build the set once, before the loop. Converting inside the loop repeats
    # the work for every item, and a one-shot iterator of keys would be
    # drained by the first item, so later items would wrongly be kept.
    excluded = as_set(keys)
    return {k: v for k, v in d.items() if k not in excluded}
def d_subset(d: Dict, keys: Iterable) -> Dict:
    '''Returns copy of 'd' with only members of 'keys'; all other items are
    omitted.

    'keys' is converted with as_set(), so it may be a set, any other
    iterable, or a single non-iterable key.'''
    # Build the set once, before the loop. Converting inside the loop repeats
    # the work for every item, and a one-shot iterator of keys would be
    # drained by the first item, so later items would wrongly be dropped.
    wanted = as_set(keys)
    return {k: v for k, v in d.items() if k in wanted}
def field_names(dclass) -> List[str]:
    '''Returns the names of dclass's dataclass fields, in the order that
    dataclasses.fields() reports them.'''
    return [fld.name for fld in fields(dclass)]
def fields_for(dclass, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    '''Returns a new dict holding only those items of kwargs whose keys are
    names of dataclass fields of dclass.'''
    wanted_names = field_names(dclass)
    return d_subset(kwargs, wanted_names)
def force_setattr(obj, attrname, value):
    '''Writes value on obj.attrname even if obj is immutable.

    Calls object.__setattr__ directly, so any __setattr__ defined by obj's
    own class (such as the one a frozen dataclass uses to reject writes) is
    bypassed.'''
    object.__setattr__(obj, attrname, value)
#TODO UT
def loose_dict_eq(d1: Dict, d2: Dict) -> bool:
    '''Are d1 and d2 equal, if a key whose value is None counts the same as
    a key that is absent?'''
    every_key = set(chain(d1.keys(), d2.keys()))
    # .get() maps a missing key to None, which makes "absent" and "None"
    # indistinguishable in the comparison.
    return not any(
        d1.get(key, None) != d2.get(key, None) for key in every_key
    )
# TODO Rename this to something clearer
def tupdict(**kwargs) -> Tuple[Tuple[str, Hashable]]:
    '''Returns the keyword arguments as a tuple of (name, value) pairs, in
    the order they were passed.'''
    return tuple(kwargs.items())  # type: ignore
def as_dict(x: Union[Dict, None, Collection[Tuple[str, Hashable]]]) -> Dict:
    '''Returns x as a dict.

    A dict is returned as is (not copied). None becomes an empty dict. A
    dataclass instance becomes a dict of its fields, one level deep. Anything
    else is passed to dict().'''
    # TODO Update type annotation to show that x can be a dataclass.
    if x is None:
        return dict()
    if isinstance(x, dict):
        return x
    if is_dataclass_instance(x):
        # dataclasses.asdict() fails on many objects because it recursively
        # makes dictionaries for all of x's members. Here, we don't recurse.
        return {name: getattr(x, name) for name in field_names(x)}
    return dict(x)
def asdict_with_classvars(x) -> Dict[str, Any]:
    '''Returns a dict mapping every name in x.__dataclass_fields__ to the
    value of that attribute on x.

    Does not recurse (see dataclasses._asdict_inner() for how to do that
    right), and fails if x lacks a class variable declared in x's class
    definition.'''
    return {name: getattr(x, name) for name in x.__dataclass_fields__}
def as_name(x):
    '''Returns x.name if x has a 'name' attribute; otherwise str(x).'''
    absent = object()  # private sentinel, so a name of None is still returned
    found = getattr(x, 'name', absent)
    return str(x) if found is absent else found
def is_dataclass_instance(x):
    '''Is x an instance of a dataclass, as opposed to a dataclass itself or
    some unrelated object?'''
    if isinstance(x, type):
        # Classes are never instances here, even when they are dataclasses.
        return False
    return is_dataclass(x)
def short(x):
    '''Returns a short string representation of x. If x has a .short()
    method defined, we call it and return its result. Otherwise we return
    str(x).'''
    try:
        # The call stays inside the try: an AttributeError raised while
        # .short() runs also falls back to str(x).
        text = x.short()
    except AttributeError:
        text = str(x)
    return text
def vcat(a, b):
    '''Concatenate value(s). Combines a and b into either a list or a
    single value, preferring the latter.

    If either argument is None, returns the other one unchanged. If a is an
    iterable (as judged by is_iter()), b is added to a with += or .append()
    and a is returned, so a mutable a is modified in place. If a is not an
    iterable, vcat creates a new list and returns that. So, calling code
    should look like this:
       a = vcat(a, b)'''
    if a is None:
        return b
    if b is None:
        return a
    if is_iter(a):
        if is_iter(b):
            a += b
        else:
            a.append(b)
        return a
    if is_iter(b):
        # Unpack b rather than computing [a] + b, which raises TypeError when
        # b is an iterable other than a list (e.g. a tuple or a set).
        return [a, *b]
    return [a, b]
def is_seq_of(x: Any, clas: Type) -> bool:
    '''Is x a list or tuple whose first element is an instance of clas?
    Only element 0 is examined. An empty list or tuple yields False.'''
    if not isinstance(x, (list, tuple)):
        return False
    try:
        head = x[0]
    except IndexError:
        return False
    return isinstance(head, clas)
# TODO rm (OAOO Node.py)
def is_nodeid(x):
    '''Is x a node id? Any int qualifies (including bool, which is a
    subclass of int).'''
    return isinstance(x, int)
def reseed(seed=None):
    '''Reseeds Python's random-number generator and returns the seed used,
    so that you can save it and reproduce the run.

    With seed=None, first reseeds from the system's own source (time or
    OS-provided randomness), draws a fresh seed below sys.maxsize, and
    reseeds with that. Otherwise reseeds with the seed given.'''
    chosen = seed
    if chosen is None:
        random.seed(None)
        chosen = random.randrange(sys.maxsize)
    random.seed(chosen)
    return chosen
def nice_object_repr(self):
    '''Stick __repr__ = nice_object_repr inside a class definition and
    repr() will return a nice string for most classes: the class name
    followed by the contents of the instance's __dict__, as formatted by
    repr_str().'''
    class_name = self.__class__.__name__
    attributes = self.__dict__.items()
    return repr_str(class_name, attributes)
def repr_str(name, items):
    '''items is a sized iterable of (name, value). Returns the string
    appropriate for repr():
       no items:      name
       one item:      name(value)
       several items: name(k1=v1, k2=v2, ...)
    Values are formatted by nrepr().'''
    count = len(items)
    if count == 0:
        return name
    if count == 1:
        only_item = next(iter(items))
        return '%s(%s)' % (name, nrepr(only_item[1]))
    pairs = ('%s=%s' % (k, nrepr(v)) for k, v in items)
    return '%s(%s)' % (name, ', '.join(pairs))
def nrepr(o):
    '''Helper for nice_object_repr(). Returns a class's bare name, a float
    rounded to three decimal places, or repr(o) for anything else.'''
    if isclass(o):
        return o.__name__
    if isinstance(o, float):
        return '%.3f' % o
    return repr(o)
class NiceRepr:
    '''Mix-in to give descendants nice_object_repr as their __repr__, so
    that repr() shows the class name and the instance's attributes.'''
    __repr__ = nice_object_repr
def csep(xs) -> str:
    '''Comma-separated string for whatever you pass it: str() of each
    element of as_iter(xs), joined by ', '.'''
    return ', '.join(map(str, as_iter(xs)))
def ssep(xs) -> str:
    '''Space-separated string for whatever you pass it: str() of each
    element of as_iter(xs), joined by a single space.'''
    return ' '.join(map(str, as_iter(xs)))
def default_field_value(f: Field) -> Any:
    '''Returns the default value for dataclass field f: the result of
    calling f.default_factory if that is callable, otherwise f.default
    (whatever it holds, even if no default was declared).'''
    factory = f.default_factory  # type: ignore
    if callable(factory):
        return factory()
    return f.default
def rescale(xs, new_total=1.0):
    '''Returns list of xs, rescaled to sum to new_total.

    If xs is empty, xs itself is returned. If xs sums to zero, no scaling
    factor exists, so new_total is divided equally among the elements.'''
    if not xs:
        return xs
    # Sum once and reuse it for both the zero test and the multiplier.
    total = sum(xs)
    if total == 0:
        count = len(xs)
        return [new_total / count] * count
    multiplier = new_total / total
    return [multiplier * x for x in xs]
def rescale_to_max(xs: Sequence[float]) -> Iterable[float]:
    '''Generates the elements of xs, rescaled so that the largest becomes
    1.0. When max(xs) is positive, every element is divided by it. When
    max(xs) <= 0.0, every element is instead shifted up by the same amount.
    Generates nothing if xs is empty.'''
    #TODO Deal with it in a nice way if max(xs) <= 0.0.
    if not xs:
        return xs
    peak = max(xs)
    if peak <= 0.0:
        # Dividing by a non-positive maximum can't work, so shift instead.
        shift = 1.0 - peak
        for value in xs:
            yield value + shift
        return
    for value in xs:
        yield value / peak
def reweight(xs: Sequence[float], s: float) -> Iterable[float]:
    '''See Ben's notes, 23-Dec-2019. 0.0 <= s <= 1.0. s small scales weights to
    be all nearly 1.0 except for the very lowest ones. s large scales weights
    to be all nearly 0.0 except for the very highest ones. If s == 0.5,
    the weights will be the same as rescale_to_max(xs). 's' stands for
    sensitivity to the weights.

    Generates one weight per element of xs: the element as rescaled by
    rescale_to_max(), raised to the power 10 ** (4 * s - 2). Generates
    nothing if xs is empty.'''
    if not xs:
        return []
    # The exponent runs from 0.01 (s == 0.0) through 1 (s == 0.5) to 100
    # (s == 1.0).
    p = 10 ** (4 * s - 2)
    for x in rescale_to_max(xs):
        yield x ** p
# TODO BUG numpy has its own random-number generator. This results in
# indeterminism because other code invokes Python's random-number generator.
#def sample_without_replacement(items, k=1, weights=None):
# '''k is number of items to choose. If k > len(items), returns only
# len(items) items.'''
# if weights is not None:
# weights = rescale(weights)
# items = [x for (i, x) in enumerate(items) if weights[i] > 0.0]
# weights = [w for w in weights if w > 0.0]
# return np.random.choice(
# items, size=min(k, len(items)), replace=False, p=weights
# )
def sample_without_replacement(items, k=1, weights=None):
    '''Returns a generator of k items sampled randomly without replacement from
    'items', weighted by 'weights'. If 'weights' is None, then all items have
    equal probability. If k > number of items, returns same result as if k =
    number of items.

    Items whose rescaled weight is not greater than zero are never chosen.
    Uses Python's 'random' module, so results are reproducible after
    random.seed().'''
    try:
        count = len(items)
    except TypeError:
        # items has no len() (e.g. a generator): materialize it first.
        items = list(items)
        count = len(items)
    scaled = rescale([1.0] * count if weights is None else weights)
    # Work on private copies so that popping never touches the caller's data.
    pool = [item for (index, item) in enumerate(items) if scaled[index] > 0.0]
    live_weights = [w for w in scaled if w > 0.0]
    for _ in range(k):
        if not pool:
            return
        chosen = random.choices(
            range(len(pool)), weights=live_weights, k=1
        )[0]
        picked = pool.pop(chosen)
        del live_weights[chosen]
        yield picked
def read_to_blank_line(f):
    '''Reads lines from f with f.readline() and returns them concatenated
    into one string. Stops at end of file or at the first line that is empty
    or all whitespace; that line is consumed but not included.'''
    pieces = []
    while True:
        line = f.readline()
        if line.strip():
            pieces.append(line)
        else:
            break
    return ''.join(pieces)
def identity(x):
    '''Returns x unchanged.'''
    return x
def always_true(*args, **kwargs) -> bool:
    '''Returns True, ignoring all arguments.'''
    return True
def always_false(*args, **kwargs) -> bool:
    '''Returns False, ignoring all arguments.'''
    return False
def clip(lb, ub, x):
    '''Returns x limited to the range lb..ub: lb if x <= lb, ub if x >= ub,
    otherwise x. The lower bound is tested first. Passing None for lb and/or
    ub gives that bound no effect.'''
    if lb is not None and x <= lb:
        return lb
    if ub is not None and x >= ub:
        return ub
    return x
def pairwise(iterable):
    '''Returns an iterator over each pair of consecutive elements:
    pairwise('ABCD') --> AB BC CD. Yields nothing if iterable has fewer than
    two elements.'''
    leaders, followers = tee(iterable)
    next(followers, None)  # start the second copy one element ahead
    return zip(leaders, followers)
def filter_none(f, iterable):
    '''Returns a list of f(i) for each i in iterable, skipping every i that
    is None and every result that is None.'''
    kept = []
    for item in iterable:
        if item is None:
            continue
        outcome = f(item)
        if outcome is not None:
            kept.append(outcome)
    return kept
# TODO UT
def intersection(*iters: Iterable) -> Set:
    '''Returns a new set, which is the intersection of all the arguments.
    The arguments may be any iterables, not just 'set' objects. With no
    arguments, returns an empty set.'''
    if not iters:
        return set()
    head, *rest = [
        it if isinstance(it, set) else set(it)
        for it in iters
    ]
    return head.intersection(*rest)
# TODO UT
def union(*sets: Iterable) -> Set:
    '''Returns a new set containing every element of every argument. The
    arguments may be any iterables. With no arguments, returns an empty
    set.'''
    combined: Set = set()
    for members in sets:
        combined.update(members)
    return combined
def first(iterable):
    '''Returns first element in iterable, or None if iterable is empty.
    The argument goes through as_iter() first, so None and non-iterables
    are accepted too.'''
    return next(iter(as_iter(iterable)), None)
def first_non_none(iterable):
    '''Returns the first non-None in iterable, or None if there isn't one.
    The argument goes through as_iter() first, so None and non-iterables
    are accepted too.'''
    candidates = (x for x in as_iter(iterable) if x is not None)
    return next(candidates, None)
# Recipe from https://docs.python.org/3.7/library/itertools.html
def unique_everseen(iterable, key=None):
    '''Generates the unique elements of iterable, preserving order and
    remembering all elements ever seen. If key is given, two elements count
    as the same when key() returns equal values for them.'''
    # unique_everseen('AAAABBBCCDAABBB') --> A B C D
    # unique_everseen('ABBCcAD', str.lower) --> A B C D
    seen = set()
    for element in iterable:
        marker = element if key is None else key(element)
        if marker not in seen:
            seen.add(marker)
            yield element
def input_integers(prompt):
    '''Prompts the user to enter a list of integers separated by spaces.
    Returns a list of ints, or None if user just hit Enter. Prompts again
    after a line that contains a non-integer (with a message) or only
    whitespace (silently).'''
    while True:
        line = input(prompt)
        if not line:
            return None
        try:
            numbers = [int(token) for token in line.split()]
        except ValueError:
            print('Please enter a list of integers separated by spaces.')
            continue
        if numbers:
            return numbers
def setattr_from_kwargs(o, kwargs, *attr_names):
    '''For each name in attr_names, sets that attribute on o to
    kwargs[name]. Raises ValueError, naming o's class and the argument, if
    a name is missing from kwargs.'''
    for attr_name in attr_names:
        try:
            setattr(o, attr_name, kwargs[attr_name])
        except KeyError:
            message = (
                f'{o.__class__.__name__} ctor missing argument "{attr_name}".'
            )
            raise ValueError(message)
class ReprEq:
    '''Mix-in that defines __eq__ and __hash__ in terms of the class's
    __repr__: two objects are equal when their repr() strings are equal, and
    an object's hash is the hash of its repr() string.'''

    def __eq__(self, other):
        mine = repr(self)
        theirs = repr(other)
        return mine == theirs

    def __hash__(self):
        return hash(repr(self))
#TODO Redo with contextlib.contextmanager
@dataclass
class PushAttr(AbstractContextManager):
    '''Context manager that remembers the value of attribute attr_name on
    object o when the 'with' block is entered, and writes that value back
    when the block exits, whether or not an exception was raised.'''
    o: object #SimpleNamespace
    attr_name: str
    saved_value: Any = None  # filled in by __enter__

    def __enter__(self):
        target, name = self.o, self.attr_name
        self.saved_value = getattr(target, name)
        return self

    def __exit__(self, *args, **kwargs):
        target, name = self.o, self.attr_name
        setattr(target, name, self.saved_value)
        # Falling off the end returns None, so exceptions are not suppressed.
@dataclass(frozen=True)
class Quote:
    '''For when you want to hold a value in a way that can be
    distinguished from values that you use to represent other things.'''
    value: Any

    @classmethod
    def get(cls, x):
        '''Returns the wrapped value if x is a Quote; otherwise returns x
        itself.'''
        return x.value if isinstance(x, Quote) else x
class ClassStrIsName(type):
    '''Metaclass that makes str() of a class return just the class's
    name.'''

    def __str__(cls):
        return cls.__name__
# Class decorators
def singleton(cls):
    """
    By Siddhesh Suhas Sathe. Copied from:
    https://github.com/siddheshsathe/handy-decorators/blob/master/src/decorators.py
    Handy decorator for creating a singleton class
    Description:
        - Decorate your class with this decorator
        - Calling the class again with the same args/kwargs as the previous
          call returns the previously created instance
        - Calling it with different args/kwargs creates a new instance, which
          replaces the remembered one (only the most recent instance is kept)
        - Works for multiple classes
        - The decorated name is bound to a wrapper function, not the class
    Use:
        >>> from decorators import singleton
        >>>
        >>> @singleton
        ... class A:
        ...     def __init__(self, *args, **kwargs):
        ...         pass
        ...
        >>>
        >>> a = A(name='Siddhesh')
        >>> b = A(name='Siddhesh', lname='Sathe')
        >>> c = A(name='Siddhesh', lname='Sathe')
        >>> a is b  # has to be different
        False
        >>> b is c  # has to be same
        True
        >>>
    """
    remembered = {}

    @functools.wraps(cls)
    def wrapper(*args, **kwargs):
        call_signature = (args, kwargs)
        entry = remembered.get(cls)
        if entry is not None and entry['args'] == call_signature:
            return entry['instance']
        # First call, or different arguments from last time: build a new
        # instance and remember it in place of the old one.
        entry = {
            'args': call_signature,
            'instance': cls(*args, **kwargs)
        }
        remembered[cls] = entry
        return entry['instance']
    return wrapper
# Debugging
def pts(ls: Iterable, n=None):
    '''Prints ls as a table of strings, one element per line. For
    debugging. An element that is itself an iterable (per is_iter()) is
    printed as its members separated by ', '. If n is not None, stops after
    n elements.'''
    for index, row in enumerate(as_iter(ls)):
        if n is not None and index >= n:
            return
        text = ', '.join(map(str, row)) if is_iter(row) else str(row)
        print(text)
def pl(x: Any):
    '''Prints each element of as_iter(x) on a line of its own.'''
    for element in as_iter(x):
        print(element)
def pr(x: Any, *args, **kwargs):
    '''Prints x as a list, one line at a time, alphabetized. If x has its
    own .pr() method, that is called instead, with the extra arguments
    passed through. A dict is printed as its (key, value) items. Sorting is
    by str() of each element.'''
    if hasattr(x, 'pr'):
        x.pr(*args, **kwargs)
        return
    rows = x.items() if isinstance(x, dict) else as_iter(x)
    pts(sorted(rows, key=str))
# for s in sorted(str(a) for a in as_iter(x)):
# print(s)