Two functions that use NumPy to compute all bit strings of a given size.
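The function bodies aren't reproduced in this note; here is a minimal sketch of what they might look like (hypothetical reconstructions, the actual implementations may differ):

```python
import numpy as np

def all_bitstrings(n):
    # Vectorized: broadcast a right shift of the integers 0..2**n - 1
    # against each bit position, then mask out the lowest bit.
    return ((np.arange(2**n)[:, None] >> np.arange(n)[::-1]) & 1).astype(np.uint8)

def all_bitstrings_slow(n):
    # Same output, but built one entry at a time in Python loops.
    out = np.zeros((2**n, n), dtype=np.uint8)
    for i in range(2**n):
        for j in range(n):
            out[i, j] = (i >> (n - 1 - j)) & 1
    return out
```

Notice the difference in compute time: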
```
%timeit all_bitstrings(10)
# 94.9 µs ± 491 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit all_bitstrings_slow(10)
# 3.03 ms ± 51.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
This neat way to get an iterator was pointed out to me by @steve_quantum:
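The iterator version isn't shown here either; `itertools.product` is the natural way to get one lazily (an assumption on my part):

```python
from itertools import product

def all_bitstrings_iterator(n):
    # Lazily yields each bit string as a tuple of 0s and 1s,
    # without materializing the full 2**n x n array up front.
    return product((0, 1), repeat=n)
```

Wrapping it in `list(...)` in the timing below forces full evaluation, so the comparison with the array-building versions is fair.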
```
%timeit list(all_bitstrings_iterator(10))
# 33.1 µs ± 88.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
`all_bitstrings_jax` uses JIT compilation in JAX. For larger sizes this gives a significant speedup.
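A plausible sketch of the JAX version, assuming it uses the same bit-shift construction as the NumPy one (again a hypothetical reconstruction):

```python
import jax
import jax.numpy as jnp
from functools import partial

# n must be marked static because it determines the output shape,
# which JIT compilation needs to know at trace time.
@partial(jax.jit, static_argnums=0)
def all_bitstrings_jax(n):
    return (jnp.arange(2**n)[:, None] >> jnp.arange(n)[::-1]) & 1
```

The `block_until_ready()` call in the timing below matters: JAX dispatches computations asynchronously, so without it the timer would stop before the work actually finishes.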
```
%timeit all_bitstrings(24)
# 6.85 s ± 344 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit list(all_bitstrings_iterator(24))
# 9.21 s ± 262 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit all_bitstrings_jax(24).block_until_ready()
# 1.61 s ± 252 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Note, however, that the JAX version utilizes about 200% CPU (two cores), while the other functions run strictly on a single core, so part of its speedup comes from parallelism rather than compilation alone.