Assignment completed as a part of the Advanced Natural Language Processing Course at MSc Cognitive Systems: Language, Learning and Reasoning
by: Saswat Dash, Pace Bailey, Dimitrije Ristic
Download the data [here](https://drive.google.com/file/d/1byX4wpe1UjyCVyYrT04sW17NnycKAK7N/view)
The script imports the following libraries:
1. argparse -> enables running the code from the terminal
2. PyTorch -> enables model training
3. Matplotlib -> enables plots and visualization
4. NumPy -> matrix and math operations
5. PIL -> image pre-processing
6. tqdm -> visualizes progress during batch execution
7. Hugging Face Transformers -> provides access to the pre-trained CLIP model and tokenizer
8. json -> enables construction of JSON files
9. h5py -> enables efficient data management
├── data
│   ├── embeddings
│   │   ├── train_image_embeddings ---> contains individual h5py img embedding files
│   │   ├── train_text_embeddings ---> contains individual h5py text embedding files
│   │   └── wrapper ---> contains h5py embedding wrapper files
│   ├── features ---> contains image and text pickled embeddings
│   ├── metrics.json
│   └── train ---> location of the dataset
│       ├── train.data.v1.txt
│       ├── train.gold.v1.txt
│       └── train_images_v1
├── data_preparation.py
├── evaluation.py
├── finetune_clip_models.py
├── helper.py
├── language_model.py
├── main.py
└── utils.py
The code consists of two independent processes:
- Data pre-processing
- Model selection and training
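A minimal sketch of how the two processes might be dispatched from the command line with argparse. The flag names come from the commands in this README; whether `--choose_model` and `--loss_function` take a value, their defaults, and the helper names `build_parser` and `describe` are assumptions for illustration, not the actual contents of main.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: flag names follow the README, everything else is assumed.
    parser = argparse.ArgumentParser(description="Assignment entry point (sketch)")
    parser.add_argument("--prepare", action="store_true",
                        help="run data pre-processing")
    parser.add_argument("--choose_model", default=None,
                        choices=["CLIP_0", "CLIP_1", "CLIP_2", "CLIP_3"],
                        help="model to train or evaluate")
    parser.add_argument("--loss_function", default=None,
                        help="loss function used for training")
    return parser


def describe(args: argparse.Namespace) -> str:
    # Decide which of the two processes the parsed arguments request.
    if args.prepare:
        return "prepare"
    if args.choose_model is not None:
        return f"train {args.choose_model} with {args.loss_function}"
    return "nothing to do"


# Parse an explicit argument list instead of sys.argv to keep the demo self-contained.
args = build_parser().parse_args(["--prepare"])
print(describe(args))  # prints "prepare"
```

Keeping the two modes behind separate flags lets pre-processing run once, while training can be repeated with different models and losses.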
Data pre-processing is activated using the following terminal command: python main.py --prepare
It performs the following steps:
1. Encodes text -> Tokenizes text data and extracts embeddings for the tokens from CLIP. Each embedding is stored in a separate h5py file.
2. Encodes images -> Resizes each image to 100x100 and converts it to RGB. Image embeddings are then extracted from CLIP. Each embedding is stored in a separate h5py file.
3. Creates wrappers -> Empty wrapper h5py files are generated.
4. Wraps image files -> The image wrapper file is structured like a Python dictionary. Each key is the name of a .jpg image file. Each value is an h5py external link to the file containing the embeddings for that image.
5. Wraps text files -> The text wrapper file is structured like a Python dictionary. Each key is the index of an input phrase. Each value is an h5py external link to the file containing the embeddings for that phrase.
6. Extracts text features -> Text features are extracted from h5py and stored inside a single tensor.
7. Extracts image features -> Image features are extracted from h5py and stored inside a dictionary. The key is the index of the text phrase, the value is a tensor consisting of 10 possible image embeddings.
8. Saves features -> Extracted and formatted embeddings are stored inside pickle files in the following directory: ‘./data/features’
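The image preparation in step 2 can be sketched with PIL as follows. This is an illustration under assumptions: the helper name `preprocess_image` is hypothetical, the order of the two operations is a guess, and a synthetic grayscale image stands in for a real file from the dataset.

```python
from PIL import Image


def preprocess_image(image: Image.Image, size=(100, 100)) -> Image.Image:
    # Convert to three-channel RGB, then resize to a fixed size.
    # Note: resize() does not preserve the aspect ratio.
    return image.convert("RGB").resize(size)


# Synthetic single-channel image used in place of a dataset image.
demo = Image.new("L", (640, 480), color=128)
small = preprocess_image(demo)
print(small.size, small.mode)  # (100, 100) RGB
```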
Many of the steps in the process were added to handle severe memory issues during feature extraction. The process was therefore built around HDF5 files, accessed through the h5py library.
HDF5 is a file format for storing and organizing large quantities of numerical data. Conveniently, h5py lets the data be stored and manipulated through NumPy arrays.
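The wrapper idea from steps 3 to 5 can be sketched with `h5py.ExternalLink`. This is a sketch under assumptions rather than the code in data_preparation.py: the dataset name `embedding`, the file names, the 512-dimensional vectors, and both helper functions are hypothetical.

```python
import os
import tempfile

import h5py
import numpy as np


def write_embedding(path: str, embedding: np.ndarray) -> None:
    # One HDF5 file per item, holding a single dataset named "embedding".
    with h5py.File(path, "w") as f:
        f.create_dataset("embedding", data=embedding)


def link_into_wrapper(wrapper_path: str, key: str, target_path: str) -> None:
    # Add an external link so that wrapper[key] resolves to the dataset
    # stored in target_path; no embedding data is copied into the wrapper.
    with h5py.File(wrapper_path, "a") as wrapper:
        wrapper[key] = h5py.ExternalLink(target_path, "/embedding")


workdir = tempfile.mkdtemp()
wrapper_path = os.path.join(workdir, "image_wrapper.h5")
for name in ["image.0.jpg", "image.1.jpg"]:
    target = os.path.join(workdir, name + ".h5")
    write_embedding(target, np.random.rand(512).astype(np.float32))
    link_into_wrapper(wrapper_path, name, target)

# The wrapper behaves like a dictionary: keys are image names and
# values are loaded from the linked files only when accessed.
with h5py.File(wrapper_path, "r") as wrapper:
    keys = sorted(wrapper.keys())
    first = wrapper["image.0.jpg"][...]
print(keys, first.shape)
```

Because each value is read from disk only on access, the full set of embeddings never has to sit in memory at once.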
Model selection and training is activated using the following terminal command: python main.py --choose_model --loss_function
The following models can be selected:
- 'CLIP_0' -> No training conducted. Inference done using the embeddings extracted from CLIP.
- 'CLIP_1' -> 1 GELU Layer and 1 Linear Layer.
- 'CLIP_2' -> 2 Fully connected Linear Layers and 2 GELU Layers.
- 'CLIP_3' -> 1 LSTM Layer, 1 fully connected Linear Layer, 2 GELU Layers.
The following loss functions can be selected:
- 'Contrastive Cosine Loss'
- 'Cross Entropy Loss'
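To make the options above more concrete, here is a PyTorch sketch of a head in the spirit of 'CLIP_2' together with the two kinds of loss. Everything beyond the layer types is an assumption: the 512-dimensional embeddings, the Linear/GELU ordering, applying the head to the text embedding only, and reading the contrastive cosine loss as `nn.CosineEmbeddingLoss`. The code in finetune_clip_models.py may differ on all of these points.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextHead(nn.Module):
    """Hypothetical head: two Linear layers, each followed by a GELU."""

    def __init__(self, dim: int = 512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim), nn.GELU(),
            nn.Linear(dim, dim), nn.GELU(),
        )

    def forward(self, text_emb: torch.Tensor) -> torch.Tensor:
        return self.net(text_emb)


def candidate_logits(text_emb: torch.Tensor, image_embs: torch.Tensor) -> torch.Tensor:
    # text_emb: [batch, dim]; image_embs: [batch, 10, dim].
    # Cosine similarity between each phrase and each of its candidate images.
    return F.cosine_similarity(text_emb.unsqueeze(1), image_embs, dim=-1)


torch.manual_seed(0)
batch, n_candidates, dim = 4, 10, 512
text = torch.randn(batch, dim)                    # stand-in text embeddings
images = torch.randn(batch, n_candidates, dim)    # stand-in image embeddings
gold = torch.randint(0, n_candidates, (batch,))   # index of the correct image

head = TextHead(dim)
logits = candidate_logits(head(text), images)     # shape [4, 10]
ce_loss = F.cross_entropy(logits, gold)           # retrieval as 10-way classification

# One possible reading of a contrastive cosine loss: pull the gold image
# towards the phrase (target +1) and push the other candidates away (target -1).
targets = -torch.ones(batch, n_candidates)
targets[torch.arange(batch), gold] = 1.0
projected = head(text).unsqueeze(1).expand(-1, n_candidates, -1)
cos_loss = nn.CosineEmbeddingLoss()(
    projected.reshape(-1, dim), images.reshape(-1, dim), targets.reshape(-1)
)
print(logits.shape, ce_loss.item(), cos_loss.item())
```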
Training and evaluation run the following steps:
1. Loads the dataset -> Loads the saved pickle files with the embeddings, constructs a dataset, and initializes variables with corresponding elements of the dataset.
2. Splits the dataset -> The dataset is split into training and testing dataloaders using the following ratio: 75% Training, 25% Testing.
3. Runs the training -> Based on the selection of the model and the loss function, NNs are initialized and training is executed.
4. Plots loss -> Upon training completion, loss, Mean Reciprocal Rank (MRR), and Hit rate are plotted against the number of epochs.
5. Executes testing -> Calculates the MRR and Hit rate over the test set. Prints the average values.
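The two evaluation metrics can be sketched in NumPy as follows. The function names are hypothetical and the cutoff `k` for the hit rate is an assumption, since this README does not state which cutoff evaluation.py uses.

```python
import numpy as np


def gold_ranks(scores: np.ndarray, gold: np.ndarray) -> np.ndarray:
    # scores: [n_samples, n_candidates], higher is better.
    # Rank of the gold candidate = 1 + number of candidates scored strictly higher.
    gold_scores = scores[np.arange(len(gold)), gold]
    return 1 + (scores > gold_scores[:, None]).sum(axis=1)


def mean_reciprocal_rank(scores: np.ndarray, gold: np.ndarray) -> float:
    # Average of 1 / rank of the gold candidate over all samples.
    return float(np.mean(1.0 / gold_ranks(scores, gold)))


def hit_rate(scores: np.ndarray, gold: np.ndarray, k: int = 1) -> float:
    # Fraction of samples whose gold candidate is ranked within the top k.
    return float(np.mean(gold_ranks(scores, gold) <= k))


scores = np.array([
    [0.9, 0.1, 0.3],   # gold 0 -> rank 1
    [0.2, 0.8, 0.5],   # gold 2 -> rank 2
    [0.4, 0.3, 0.1],   # gold 2 -> rank 3
])
gold = np.array([0, 2, 2])
print(mean_reciprocal_rank(scores, gold))  # (1 + 1/2 + 1/3) / 3
print(hit_rate(scores, gold, k=1))         # 1/3
```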