coding/decoding vector optimisation #8
I was wondering about the same, so I tried to quickly write one:

```python
import numpy as np
import numba

__base32 = '0123456789bcdefghjkmnpqrstuvwxyz'

@numba.njit()
def encode_numba(latitude, longitude):
    precision = 12
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = np.zeros(precision, dtype='<U1')
    bits = np.array([16, 8, 4, 2, 1])
    bit = 0
    ch = 0
    n = 0
    even = True
    while n < precision:
        if even:
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                ch |= bits[bit]
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                ch |= bits[bit]
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even
        if bit < 4:
            bit += 1
        else:
            geohash[n] = __base32[ch]
            bit = 0
            ch = 0
            n += 1
    return ''.join(geohash)
```

This already provides a significant speedup:

```python
from pygeohash import encode

%timeit encode(50, 14)
# 15.3 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit encode_numba(50, 14)
# 2.81 µs ± 84.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

It would possibly also make sense to work with an integer representation instead of characters:

```python
@numba.njit()
def encode_numba_int(latitude, longitude, precision=12):
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = np.zeros(precision, dtype='int')  # <-- CHANGE HERE
    bits = np.array([16, 8, 4, 2, 1])
    bit = 0
    ch = 0
    n = 0
    even = True
    while n < precision:
        if even:
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                ch |= bits[bit]
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                ch |= bits[bit]
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even
        if bit < 4:
            bit += 1
        else:
            geohash[n] = ch  # <-- CHANGE HERE
            bit = 0
            ch = 0
            n += 1
    return geohash  # <-- CHANGE HERE
```

This results in another speedup (we're down to about 1/25 of the original time):

```python
%timeit encode_numba_int(50, 14)
# 641 ns ± 49.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```

One could of course change the code so that it doesn't even attempt to create an array with a base-32 representation and merely constructs a single int. If the authors are interested in this direction, I'd be happy to start a PR (after specifying the resulting API a bit). I'd also be happy for more feedback on Numba, as I'm hardly an expert.
There also appears to be some recent effort to refactor and speed up PyGeohash here: https://github.com/tastatham/gsoc_dask_geopandas_2021/issues/2
I've tried doing about the same with the decode function. It's a bit more complicated, since it uses dictionaries. I tried numba dictionaries, but for some reason they only have setters and you can't get an item from them. So finally I made use of the following:

```python
from math import log10

from numba import njit

@njit('int8(unicode_type)')
def base32_to_int(s):
    # Arithmetic replacement for a dict lookup: maps a base-32
    # character to its index in '0123456789bcdefghjkmnpqrstuvwxyz'.
    res = ord(s) - 48
    if res > 9:  res -= 40
    if res > 16: res -= 1
    if res > 18: res -= 1
    if res > 20: res -= 1
    return res

@njit('UniTuple(float64, 4)(unicode_type)')
def numba_decode_exactly(geohash):
    lat_interval_neg, lat_interval_pos, lon_interval_neg, lon_interval_pos = -90, 90, -180, 180
    lat_err, lon_err = 90, 180
    is_even = True
    for c in geohash:
        cd = base32_to_int(c)
        for mask in (16, 8, 4, 2, 1):
            if is_even:  # adds longitude info
                lon_err /= 2
                if cd & mask:
                    lon_interval_neg = (lon_interval_neg + lon_interval_pos) / 2
                else:
                    lon_interval_pos = (lon_interval_neg + lon_interval_pos) / 2
            else:  # adds latitude info
                lat_err /= 2
                if cd & mask:
                    lat_interval_neg = (lat_interval_neg + lat_interval_pos) / 2
                else:
                    lat_interval_pos = (lat_interval_neg + lat_interval_pos) / 2
            is_even = not is_even
    lat = (lat_interval_neg + lat_interval_pos) / 2
    lon = (lon_interval_neg + lon_interval_pos) / 2
    return lat, lon, lat_err, lon_err

@njit('UniTuple(float64, 2)(unicode_type)')
def numba_decode(geohash):
    """
    Decode geohash, returning two floats with latitude and longitude
    containing only relevant digits and with trailing zeroes removed.
    """
    lat, lon, lat_err, lon_err = numba_decode_exactly(geohash)
    # Format to the number of decimals that are known
    lat_prec = max(1, int(round(-log10(lat_err))))
    lon_prec = max(1, int(round(-log10(lon_err))))
    lat = round(lat, lat_prec)
    lon = round(lon, lon_prec)
    return lat, lon
```

I also modified the way precision is cast into strings and then into floats, so that I could find a way to vectorize it afterwards. The improvement in performance depends on the precision/size of the geohashes, so I fixed it to 12:

```python
import random

geohash = ''.join(random.sample(__base32, 12))
```

```python
%%timeit
decode(geohash)
# 19.9 µs ± 935 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

```python
%%timeit
numba_decode(geohash)
# 4.37 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
Update: I've made some modifications to speed up processing of geohash arrays (an array with n geohashes), and it's pretty promising (a ~10x speedup compared to a NumPy vectorization). I don't know if the owners are still around, but I would love to start a PR.
Would be happy to review a PR if you're still looking at this.
I think it's worth doing as an optional path, like if installed as |
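One common pattern for wiring up such an optional fast path (a sketch of the general technique, not pygeohash's actual packaging) is a guarded import with a no-op fallback decorator, so the pure-Python implementation still works when numba is absent:

```python
try:
    from numba import njit
except ImportError:
    # numba not installed: fall back to a no-op decorator that accepts
    # both the bare @njit form and the @njit(...) / @njit('sig') forms.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        def wrap(func):
            return func
        return wrap

@njit
def add_one(x):
    return x + 1
```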
Is there any way to optimize the code for faster decoding, in particular on multiple geohashes?
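For decoding many geohashes, the usual numba approach is to move the loop over the hashes inside a single compiled function, so the per-call dispatch overhead is paid only once. Here is a hedged pure-Python sketch of that shape (the names `decode_exactly_py` and `decode_many` are hypothetical; in a numba version both would be `@njit`-compiled):

```python
_BASE32 = '0123456789bcdefghjkmnpqrstuvwxyz'
_DECODE = {c: i for i, c in enumerate(_BASE32)}

def decode_exactly_py(geohash):
    """Decode one geohash to (lat, lon, lat_err, lon_err)."""
    lat_lo, lat_hi = -90.0, 90.0
    lon_lo, lon_hi = -180.0, 180.0
    lat_err, lon_err = 90.0, 180.0
    even = True
    for c in geohash:
        cd = _DECODE[c]
        for mask in (16, 8, 4, 2, 1):
            if even:  # even bits refine longitude
                lon_err /= 2
                if cd & mask:
                    lon_lo = (lon_lo + lon_hi) / 2
                else:
                    lon_hi = (lon_lo + lon_hi) / 2
            else:  # odd bits refine latitude
                lat_err /= 2
                if cd & mask:
                    lat_lo = (lat_lo + lat_hi) / 2
                else:
                    lat_hi = (lat_lo + lat_hi) / 2
            even = not even
    return ((lat_lo + lat_hi) / 2, (lon_lo + lon_hi) / 2, lat_err, lon_err)

def decode_many(geohashes):
    """Batch decode. In a numba version this loop would live inside one
    @njit function (and could fill a preallocated (n, 2) float array)."""
    return [decode_exactly_py(g)[:2] for g in geohashes]
```

Compiling the whole batch loop is what typically yields the array-level speedups mentioned above, rather than calling a compiled single-hash decoder once per element from Python.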