Potential speed up of projector training. #23

Open
JackGodfrey122 opened this issue Mar 29, 2023 · 1 comment

@JackGodfrey122

Hey Daniel, I hope you're doing well!

I was hoping to get your thoughts on a potential method to drastically speed up the training time of the projector. Currently, during training a call is made to a fitted search structure such as a BallTree. While this does accept batches of data, it still appears to be quite a large bottleneck. I briefly profiled the train_projector.py script (only for 500 iterations, with no results generated) and found that the total run time is roughly 20 seconds, with about 12 of those seconds spent in the calls to the BallTree.
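
For reference, a quick standalone way to see this kind of cost is to time batched BallTree queries directly. This is only a sketch: the database size and feature dimensionality below are placeholders, not the real ones from the repo.

import time

import numpy as np
from sklearn.neighbors import BallTree

# Placeholder feature database; the real size and dimensionality will differ.
features = np.random.randn(100_000, 27).astype(np.float32)
tree = BallTree(features)

# One batched nearest-neighbor lookup per training iteration.
batch = features[:32] + 0.1 * np.random.randn(32, 27).astype(np.float32)
start = time.time()
for _ in range(500):
    _, idx = tree.query(batch, k=1)
print(f'500 batched BallTree queries took {time.time() - start:.2f} seconds')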

I don't think there is any need for this nearest neighbor search call during the training process. Instead, we can generate this data offline first using optimized tools such as FAISS.

To make a fair comparison, I generated 16,000,000 queries (32 * 500,000) using the same method you do: a form of scaled Gaussian noise added to samples of the feature database. I then queried a FAISS index which uses the Euclidean distance measure under the hood. The output of this is 16,000,000 queries and their associated closest feature indexes. Generating this data takes roughly 15 minutes on an AMD Ryzen 9 5950X. I don't think FAISS is officially supported on GPU for Windows, but I seem to remember them claiming that on average the GPU implementation is roughly 5x faster than CPU, so maybe there are more gains to be made...
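
For what it's worth, on platforms where a GPU build of FAISS is available, the same flat index can be moved onto the GPU(s). A minimal sketch, assuming the faiss-gpu package is installed and using random stand-in data:

import faiss
import numpy as np

# Stand-in data; in practice this would be the real feature database and queries.
features = np.random.randn(100_000, 27).astype(np.float32)
queries = np.random.randn(1_000_000, 27).astype(np.float32)

cpu_index = faiss.IndexFlatL2(features.shape[1])
cpu_index.add(features)

# index_cpu_to_all_gpus is only available in GPU-enabled builds of faiss.
gpu_index = faiss.index_cpu_to_all_gpus(cpu_index)
_, closest_indexes = gpu_index.search(queries, 1)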

Once this data has been generated, we can bypass the BallTree call altogether and simply draw our training batches from this offline data instead. With this change, the projector trains on an NVIDIA 3080 with a batch size of 1024 for 50,000 iterations in roughly 3 minutes, with a final loss of 0.520.

Here is the code to produce the offline data:

import os
import struct
import time

import faiss
import numpy as np


class ProjectorDatasetBuilder:

    def __init__(self):
        pass

    def load_features_from_file(self, filename: str) -> None:

        with open(filename, 'rb') as f:
            
            self.n_frames, self.n_features = struct.unpack('II', f.read(8))
            self.features: np.ndarray = np.frombuffer(
                f.read(self.n_frames * self.n_features * 4),
                dtype=np.float32,
                count=self.n_frames * self.n_features,
            ).reshape([self.n_frames, self.n_features])
    
        self.features_noise_std: np.ndarray = self.features.std(axis=0) + 1.0

    def make_queries(self, n_queries: int) -> None:
        samples = np.random.randint(0, self.n_frames, size=[n_queries]) # n_queries
        n_sigma = np.random.uniform(size=[n_queries, 1]).astype(np.float32)  # n_queries x 1
        noise = np.random.normal(size=[n_queries, self.n_features]).astype(np.float32) # n_queries x n_features
        
        # here we scale our noise and add to our samples
        queries = self.features[samples] + self.features_noise_std * n_sigma * noise
        self.queries = queries
    
    def build_faiss_index(self) -> None:
        start_time = time.time()
        self.index = faiss.IndexFlatL2(self.n_features)
        self.index.add(self.features)
        index_build_time = time.time() - start_time
        print(f'Time taken to build index: {index_build_time} seconds')
        self._sanity_check()
    
    def _sanity_check(self) -> None:
        """
        Here we are passing in the first 5 feature rows into our search.
        We expect to see the returned distances as 0, and the indexes to be 0 through 4,
        as these data points are explicitly part of the search index.
        """
        distances, indexes = self.index.search(self.features[:5], 1)
        print('Sanity check...')
        for i, (d, ii) in enumerate(zip(distances, indexes)):
            print(f'Closest neighbor of point {i} is {ii} with distance {d}')
    

    def get_projector_data(self) -> None:
        start_time = time.time()
        _, closest_indexes = self.index.search(self.queries, 1) # main call
        search_time = time.time() - start_time
        print(f'Time taken for search: {search_time} seconds.')
        self.closest_indexes = closest_indexes
    
    def save(
            self,
            dst_folder: str,
            query_filename: str = 'test_queries.npy',
            nearest_indexes_filename: str = 'test_nearest_indexes.npy'
            ) -> None:
        np.save(os.path.join(dst_folder, query_filename), self.queries)
        np.save(os.path.join(dst_folder, nearest_indexes_filename), self.closest_indexes)


if __name__ == "__main__":

    np.random.seed(1234)
    features_filename = "./features.bin"
    output_folder = "./output"
    num_queries = 16_000_000
    builder = ProjectorDatasetBuilder()
    builder.load_features_from_file(features_filename)
    builder.make_queries(num_queries)
    builder.build_faiss_index()
    builder.get_projector_data()
    builder.save(output_folder)
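
And for completeness, a rough sketch of how the saved arrays could then be consumed when drawing training batches. The filenames match the save() defaults above; the actual projector network and loss are omitted since they are unchanged.

import numpy as np

queries = np.load('./output/test_queries.npy')                   # n_queries x n_features
nearest_indexes = np.load('./output/test_nearest_indexes.npy')   # n_queries x 1

batch_size = 1024
for _ in range(50_000):
    batch = np.random.randint(0, len(queries), size=batch_size)
    batch_queries = queries[batch]              # network inputs
    batch_targets = nearest_indexes[batch, 0]   # index of the closest database feature per query
    # ... feed batch_queries / batch_targets into the existing projector training step ...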

Just thought you might be interested!

@orangeduck
Owner

Hi Jack,

Great idea! Really nice optimization 👍

Dan
