diff --git a/lshash/lshash.py b/lshash/lshash.py index 5c895a6..5c529d9 100644 --- a/lshash/lshash.py +++ b/lshash/lshash.py @@ -204,6 +204,24 @@ def index(self, input_point, extra_data=None): table.append_val(self._hash(self.uniform_planes[i], input_point), value) + def code(self, input_point): + """ Calculate LSH code for a single input point. Returns one code of + length `hash_size` for each `hash_table`. + + :param input_point: + A list, or tuple, or numpy ndarray object that contains numbers + only. The dimension needs to be 1 * `input_dim`. + This object will be converted to Python tuple and stored in the + selected storage. + """ + + if isinstance(input_point, np.ndarray): + input_point = input_point.tolist() + + return [self._hash(self.uniform_planes[i], input_point) + for i in xrange(self.num_hashtables)] + + def query(self, query_point, num_results=None, distance_func=None): """ Takes `query_point` which is either a tuple or a list of numbers, returns `num_results` of results as a list of tuples that are ranked diff --git a/tests/test_code.py b/tests/test_code.py new file mode 100644 index 0000000..36a105f --- /dev/null +++ b/tests/test_code.py @@ -0,0 +1,16 @@ +import lshash +import numpy as np + +def test_code(): + """ test codes are of the correct length and work with multiple hash tables """ + l = lshash.LSHash(10, 20) + + assert len(l.code(np.random.randn(20))) == 1 + assert len(l.code(np.random.randn(20))[0]) == 10 + + l = lshash.LSHash(10, 20, num_hashtables=3) + + assert len(l.code(np.random.randn(20))) == 3 + + for hash in l.code(np.random.randn(20)): + assert len(hash) == 10