coding/decoding vector optimisation #8

Open
IlyasMoutawwakil opened this issue Aug 11, 2021 · 7 comments

@IlyasMoutawwakil
Contributor

Is there any way to optimize the code for faster decoding, in particular on multiple geohashes?

@DahnJ

DahnJ commented Aug 13, 2021

I was wondering about the same thing, so I tried to quickly write encode in Numba:

import numpy as np
import numba

__base32 = '0123456789bcdefghjkmnpqrstuvwxyz'  # geohash alphabet (a, i, l and o are omitted)

@numba.njit()
def encode_numba(latitude, longitude):
    precision = 12
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = np.zeros(precision, dtype='<U1')
    bits = np.array([16, 8, 4, 2, 1])
    bit = 0
    ch = 0
    n = 0
    even = True
    while n < precision:
        if even:  # even bits refine longitude, odd bits refine latitude
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                ch |= bits[bit]
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                ch |= bits[bit]
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even
        
        if bit < 4:
            bit += 1
        else:  # every 5 bits, emit one base32 character
            geohash[n] = __base32[ch]
            bit = 0
            ch = 0
            n += 1
            
    return ''.join(geohash)

This already provides a significant speedup

from pygeohash import encode

%timeit encode(50, 14)
# 15.3 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit encode_numba(50, 14)
# 2.81 µs ± 84.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

It would possibly also make sense to use an int representation. As a quick PoC, I just changed the dtype to int:

@numba.njit()
def encode_numba_int(latitude, longitude, precision=12):
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = np.zeros(precision, dtype='int')  #  <-- CHANGE HERE
    bits = np.array([16, 8, 4, 2, 1])
    bit = 0
    ch = 0
    n = 0
    even = True
    while n < precision:
        if even:
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                ch |= bits[bit]
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                ch |= bits[bit]
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even
        
        if bit < 4:
            bit += 1
        else: 
            geohash[n] = ch  #  <-- CHANGE HERE
            bit = 0
            ch = 0
            n += 1
            
    return geohash  #  <-- CHANGE HERE

This results in another speedup (we're down to about 1/25 of the original time)

%timeit encode_numba_int(50, 14)
# 641 ns ± 49.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

One could of course change the code so that it doesn't even attempt to create an array with base-32 representation and merely constructs a single int.
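
A rough sketch of what that single-int variant could look like (my own untested assumption, just to illustrate the idea):

import numba

@numba.njit()
def encode_numba_single_int(latitude, longitude, precision=12):
    # Same bisection as above, but all precision * 5 bits are packed into one
    # integer (fits in an int64 for precision <= 12) instead of a character array.
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = 0
    even = True
    for _ in range(precision * 5):
        geohash <<= 1
        if even:  # even bits refine longitude
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                geohash |= 1
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:  # odd bits refine latitude
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                geohash |= 1
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even
    return geohash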

If the authors are interested in this direction, I'd be happy to start a PR (after specifying the resulting API a bit). I'd also be happy for more feedback on Numba, as I'm hardly an expert.

@DahnJ

DahnJ commented Aug 14, 2021

There also appears to be some recent effort to refactor and speed up PyGeohash here: https://github.com/tastatham/gsoc_dask_geopandas_2021/issues/2

@IlyasMoutawwakil
Contributor Author

I've tried doing roughly the same with the decode function. It's a bit more complicated since it uses dictionaries. I tried Numba typed dictionaries, but for some reason they only seemed to support setting items, not getting them. So in the end I made use of the ord builtin, since it's implemented in Numba; that way you don't even need access to the global variable __base32. Some modifications didn't add much performance, but I kept them anyway. My numba_decode function is the following:

from math import log10
from numba import njit

@njit('int8(unicode_type)')
def base32_to_int(s):
    # Map a base32 character to its index; the geohash alphabet skips a, i, l and o.
    res = ord(s) - 48       # '0'..'9' -> 0..9
    if res > 9: res -= 40   # 'b' -> 10, jumping the gap between '9' and 'b'
    if res > 16: res -= 1   # account for the missing 'i'
    if res > 18: res -= 1   # account for the missing 'l'
    if res > 20: res -= 1   # account for the missing 'o'
    return res

@njit('UniTuple(float64, 4)(unicode_type)')
def numba_decode_exactly(geohash):
    lat_interval_neg, lat_interval_pos, lon_interval_neg, lon_interval_pos = -90, 90, -180, 180
    lat_err, lon_err = 90, 180
    is_even = True
    for c in geohash:
        cd=base32_to_int(c)
        for mask in (16, 8, 4, 2, 1):
            if is_even:  # adds longitude info
                lon_err /= 2
                if cd & mask:
                    lon_interval_neg = (lon_interval_neg + lon_interval_pos) / 2
                else:
                    lon_interval_pos = (lon_interval_neg + lon_interval_pos) / 2
            else:  # adds latitude info
                lat_err /= 2
                if cd & mask:
                    lat_interval_neg = (lat_interval_neg + lat_interval_pos) / 2
                else:
                    lat_interval_pos = (lat_interval_neg + lat_interval_pos) / 2
            is_even = not is_even
    lat = (lat_interval_neg + lat_interval_pos) / 2
    lon = (lon_interval_neg + lon_interval_pos) / 2
    return lat, lon, lat_err, lon_err

@njit('UniTuple(float64, 2)(unicode_type)')
def numba_decode(geohash):
    """
    Decode a geohash, returning two floats with latitude and longitude
    containing only the relevant digits and with trailing zeroes removed.
    """
    
    lat, lon, lat_err, lon_err = numba_decode_exactly(geohash)
    # Format to the number of decimals that are known
    lat_prec = max(1, int(round(-log10(lat_err))))
    lon_prec = max(1, int(round(-log10(lon_err))))
    lat = round(lat, lat_prec)
    lon = round(lon, lon_prec)
    
    return lat, lon

I also changed how the precision is applied: instead of casting the values to strings and back to floats, they are rounded directly, so that I could find a way to vectorize it afterwards.

The performance improvement depends on the precision/size of the geohashes, so I fixed it at 12:

import random
from pygeohash import decode

__base32 = '0123456789bcdefghjkmnpqrstuvwxyz'
geohash = ''.join(random.sample(__base32, 12))

%timeit decode(geohash)
# 19.9 µs ± 935 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit numba_decode(geohash)
# 4.37 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

@IlyasMoutawwakil
Contributor Author

Update: I've made some modifications to speed up the processing of geohash arrays (an array with n geohashes), and the results are pretty promising (roughly a 10x speedup compared to a NumPy vectorization). I don't know if the owners are still active, but I would love to start a PR.
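
A minimal sketch of the kind of array interface I mean (illustrative only, not the actual code from my modifications):

import numpy as np
from numba import njit

@njit
def numba_decode_array(geohashes):
    # Decode a sequence of geohash strings (e.g. a numba.typed.List of str)
    # into two float64 arrays, reusing the scalar numba_decode from above.
    n = len(geohashes)
    lats = np.empty(n, dtype=np.float64)
    lons = np.empty(n, dtype=np.float64)
    for i in range(n):
        lat, lon = numba_decode(geohashes[i])
        lats[i] = lat
        lons[i] = lon
    return lats, lons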

@wdm0006
Owner

wdm0006 commented Oct 13, 2021

Would be happy to review a PR if you're still looking at this.

@IlyasMoutawwakil
Contributor Author

IlyasMoutawwakil commented Oct 28, 2021

@wdm0006 check the code structure and performance gains in this repo; if it looks worth adding to this package, I'll submit a PR. Also, if you have any ideas on how to fully vectorize the computations (with some matrix or tensor operations), I would love to implement them.
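
To clarify what I mean by fully vectorizing: something like running each bisection step on whole coordinate arrays at once, roughly like this (untested sketch of my own, assuming 1-D NumPy arrays of latitudes and longitudes):

import numpy as np

_BASE32 = np.array(list('0123456789bcdefghjkmnpqrstuvwxyz'))

def encode_vectorized(lats, lons, precision=12):
    # Each of the precision * 5 bisection steps operates on the whole arrays,
    # so the Python-level loop runs only precision * 5 times regardless of
    # how many points there are.
    lats = np.asarray(lats, dtype=np.float64)
    lons = np.asarray(lons, dtype=np.float64)
    lo_lat, hi_lat = np.full_like(lats, -90.0), np.full_like(lats, 90.0)
    lo_lon, hi_lon = np.full_like(lons, -180.0), np.full_like(lons, 180.0)
    bits = np.zeros((lats.size, precision * 5), dtype=np.uint8)
    for i in range(precision * 5):
        if i % 2 == 0:  # even bits refine longitude
            mid = (lo_lon + hi_lon) / 2
            upper = lons > mid
            bits[:, i] = upper
            lo_lon = np.where(upper, mid, lo_lon)
            hi_lon = np.where(upper, hi_lon, mid)
        else:  # odd bits refine latitude
            mid = (lo_lat + hi_lat) / 2
            upper = lats > mid
            bits[:, i] = upper
            lo_lat = np.where(upper, mid, lo_lat)
            hi_lat = np.where(upper, hi_lat, mid)
    # Collapse every 5 bits into a base32 index, then look up the characters.
    weights = np.array([16, 8, 4, 2, 1], dtype=np.uint8)
    indices = bits.reshape(lats.size, precision, 5) @ weights
    return np.array([''.join(row) for row in _BASE32[indices]])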

@wdm0006
Owner

wdm0006 commented Oct 29, 2021

I think it's worth doing as an optional path, e.g. installed as pygeohash[numba] or something like that. The dependencies of the existing library are very light, so for many non-performance use cases the current version would probably be preferred. Happy to review a PR with that caveat.
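
Roughly the pattern I'm thinking of (module and extra names below are illustrative, not the actual package layout):

# setup.py / setup.cfg: declare numba as an optional extra, e.g.
#   extras_require={"numba": ["numba", "numpy"]}
#
# inside the package, use the fast path only when it's available:
try:
    from ._numba_geohash import encode, decode   # Numba-accelerated versions
except ImportError:                              # numba not installed
    from ._pure_geohash import encode, decode    # existing pure-Python versions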
