Fix bug causing beams to not be reordered by log probability. #86

Open
wants to merge 9 commits into master
31 changes: 27 additions & 4 deletions README.md
@@ -1,14 +1,14 @@
# Neuraltalk2-pytorch
# ImageCaptioning.pytorch

Changes compared to neuraltalk2.
This is an image captioning codebase in PyTorch. If you are familiar with neuraltalk2, here are the main differences.
- Instead of using random split, we use [karpathy's train-val-test split](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip).
- Instead of including the convnet in the model, we use preprocessed features. (A finetunable CNN version is in the branch **with_finetune**.)
- Use ResNet instead of VGG; the feature extraction method is the same as in self-critical: run the CNN on the original image and adaptively average-pool the last conv-layer feature to a fixed size.
- Many more models (check out the models folder). The latest topdown model can achieve a 1.07 CIDEr score on Karpathy's test split with beam size 5.

## Requirements
Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for Python 3)
PyTorch 0.2 (along with torchvision)
PyTorch 0.4.1 (along with torchvision)

You need to download the pretrained ResNet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`.

@@ -31,6 +31,7 @@ Once we have these, we can now invoke the `prepro_*.py` script, which will read
```bash
$ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk
$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT

```

`prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json`, and the discretized caption data are dumped into `data/cocotalk_label.h5`.
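For intuition, the thresholding works roughly like this (a minimal sketch; `build_vocab`, `count_thr`, and the whitespace tokenization are illustrative, not the script's exact internals):

```python
from collections import Counter

def build_vocab(captions, count_thr=5):
    counts = Counter(w for cap in captions for w in cap.split())
    # keep words seen more than count_thr times; the rest map to UNK
    vocab = [w for w, n in counts.items() if n > count_thr]
    vocab.append('UNK')
    return vocab

print(build_vocab(['a dog runs', 'a cat naps', 'a dog naps'], count_thr=1))
# ['a', 'dog', 'naps', 'UNK']  (order may vary)
```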
@@ -39,6 +40,12 @@ $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_di

(Check the prepro scripts for more options, like other resnet models or other attention sizes.)

**Legacy:** previously we extracted features into separate npy/npz files for each image, but these were slower to load on some NFS setups and harder to copy around. We now save all the features in a single h5 file. If you want to convert previous npy/npz files to the h5 format, you can run

```bash
$ python scripts/convert_old.py --input_json data/dataset_coco.json --fc_input_dir data/cocotalk_fc/ --att_input_dir data/cocotalk_att/ --fc_output_dir data/cocotalk_fc --att_output_dir data/cocotalk_att/
```
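The converted h5 files are keyed by image id, which is how the updated `dataloader.py` below reads them back. A minimal sketch of loading one image's fc feature (the image id here is hypothetical):

```python
import h5py
import numpy as np

# feats_fc.h5 stores one dataset per image, keyed by the image id
with h5py.File('data/cocotalk_fc/feats_fc.h5', 'r') as feats_fc:
    fc_feat = np.array(feats_fc['391895']).astype('float32')
    print(fc_feat.shape)
```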

**Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset.

### Start training
@@ -97,6 +104,22 @@ The default split to evaluate is test. The default inference method is greedy decoding

**Live demo**. Not supported for now. Pull requests are welcome.

## Reference
If you find this implementation helpful, please consider citing this repo:

```
@misc{Luo2017,
author = {Ruotian Luo},
title = {An Image Captioning codebase in PyTorch},
year = {2017},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ruotianluo/ImageCaptioning.pytorch}},
}
```

Of course, please also cite the original papers of the models you are using (you can find the references in the model files).

## Acknowledgements

Thanks the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and awesome PyTorch team.
Thanks to the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and the awesome PyTorch team.
157 changes: 93 additions & 64 deletions dataloader.py
@@ -8,22 +8,17 @@
import numpy as np
import random

import torch
import torch.utils.data as data

import multiprocessing

def get_npy_data(ix, fc_file, att_file, use_att):
if use_att == True:
return (np.load(fc_file), np.load(att_file)['feat'], ix)
else:
return (np.load(fc_file), np.zeros((1,1,1)), ix)

class DataLoader(data.Dataset):

def reset_iterator(self, split):
del self._prefetch_process[split]
self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
self._prefetch_process[split] = BlobFetcher(split,
self, split == 'train')
self.iterators[split] = 0

def get_vocab_size(self):
@@ -35,22 +30,40 @@ def get_vocab(self):
def get_seq_length(self):
return self.seq_length

def read_files(self):
self.feats_fc = h5py.File(os.path.join(
self.opt.input_fc_dir, 'feats_fc.h5'), 'r')
self.feats_att = h5py.File(os.path.join(
self.opt.input_att_dir, 'feats_att.h5'), 'r')

def get_data(self, ix):
self.read_files()
index = str(self.info['images'][ix]['id'])
if self.use_att:
return (np.array(self.feats_fc[index]).astype('float32'),
np.array(self.feats_att[index]).astype('float32'), ix)
else:
return (np.array(self.feats_fc[index]).astype('float32'),
np.zeros((1, 1, 1)).astype('float32'), ix)

def __init__(self, opt):
self.opt = opt
self.batch_size = self.opt.batch_size
self.seq_per_img = opt.seq_per_img
self.use_att = getattr(opt, 'use_att', True)

# load the json file which contains additional information about the dataset
# load json file which contains additional information about dataset
print('DataLoader loading json file: ', opt.input_json)
self.info = json.load(open(self.opt.input_json))
self.ix_to_word = self.info['ix_to_word']
self.vocab_size = len(self.ix_to_word)
print('vocab size is ', self.vocab_size)

# open the hdf5 file
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_label_h5)
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
print('DataLoader loading h5 file: ', opt.input_fc_dir,
opt.input_att_dir, opt.input_label_h5)
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r',
driver='core')

self.input_fc_dir = self.opt.input_fc_dir
self.input_att_dir = self.opt.input_att_dir
@@ -64,7 +77,7 @@ def __init__(self, opt):
self.label_end_ix = self.h5_label_file['label_end_ix'][:]

self.num_images = self.label_start_ix.shape[0]
print('read %d image features' %(self.num_images))
print('read %d image features' % (self.num_images))

# separate out indexes for each of the provided splits
self.split_ix = {'train': [], 'val': [], 'test': []}
@@ -76,120 +89,132 @@ def __init__(self, opt):
self.split_ix['val'].append(ix)
elif img['split'] == 'test':
self.split_ix['test'].append(ix)
elif opt.train_only == 0: # restval
elif opt.train_only == 0: # restval
self.split_ix['train'].append(ix)

print('assigned %d images to split train' %len(self.split_ix['train']))
print('assigned %d images to split val' %len(self.split_ix['val']))
print('assigned %d images to split test' %len(self.split_ix['test']))
print('assigned %d images to split train' % len(self.split_ix['train']))
print('assigned %d images to split val' % len(self.split_ix['val']))
print('assigned %d images to split test' % len(self.split_ix['test']))

self.iterators = {'train': 0, 'val': 0, 'test': 0}
self._prefetch_process = {} # The three prefetch process

self._prefetch_process = {} # The three prefetch process
for split in self.iterators.keys():
self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
self._prefetch_process[split] = BlobFetcher(split,
self,
split == 'train')
# Terminate the child processes when the parent exits

def cleanup():
print('Terminating BlobFetcher')
for split in self.iterators.keys():
del self._prefetch_process[split]

import atexit
atexit.register(cleanup)

def get_batch(self, split, batch_size=None, seq_per_img=None):
batch_size = batch_size or self.batch_size
seq_per_img = seq_per_img or self.seq_per_img

fc_batch = [] # np.ndarray((batch_size * seq_per_img, self.opt.fc_feat_size), dtype = 'float32')
att_batch = [] # np.ndarray((batch_size * seq_per_img, 14, 14, self.opt.att_feat_size), dtype = 'float32')
label_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int')
mask_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'float32')
fc_batch = []
att_batch = []
label_batch = np.zeros(
[batch_size * seq_per_img, self.seq_length + 2], dtype='int')
mask_batch = np.zeros(
[batch_size * seq_per_img, self.seq_length + 2], dtype='float32')

wrapped = False

infos = []
gts = []

for i in range(batch_size):
import time
t_start = time.time()
# fetch image
tmp_fc, tmp_att,\
ix, tmp_wrapped = self._prefetch_process[split].get()
fc_batch += [tmp_fc] * seq_per_img
att_batch += [tmp_att] * seq_per_img

# fetch the sequence labels
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
ix1 = self.label_start_ix[ix] - 1 # label_start_ix starts from 1
ix2 = self.label_end_ix[ix] - 1
ncap = ix2 - ix1 + 1 # number of captions available for this image
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
ncap = ix2 - ix1 + 1 # number of captions available for this image
assert ncap > 0, 'an image does not have any label.'

if ncap < seq_per_img:
# we need to subsample (with replacement)
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
seq = np.zeros([seq_per_img, self.seq_length], dtype='int')
for q in range(seq_per_img):
ixl = random.randint(ix1,ix2)
seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length]
ixl = random.randint(ix1, ix2)
seq[q, :] = self.h5_label_file['labels'][ixl,
:self.seq_length]
else:
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length]

label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = seq
seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img,
:self.seq_length]

label_batch[i * seq_per_img: (i + 1) * seq_per_img,
1: self.seq_length + 1] = seq

if tmp_wrapped:
wrapped = True

# Used for reward evaluation
gts.append(self.h5_label_file['labels'][self.label_start_ix[ix] - 1: self.label_end_ix[ix]])

gts.append(
self.h5_label_file['labels'][self.label_start_ix[ix] - 1:
self.label_end_ix[ix]])

# record associated info as well
info_dict = {}
info_dict['ix'] = ix
info_dict['id'] = self.info['images'][ix]['id']
info_dict['file_path'] = self.info['images'][ix]['file_path']
infos.append(info_dict)
#print(i, time.time() - t_start)

# generate mask
t_start = time.time()
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, label_batch)))
nonzeros = np.array(list(map(lambda x: (x != 0).sum() + 2, label_batch)))
for ix, row in enumerate(mask_batch):
row[:nonzeros[ix]] = 1
#print('mask', time.time() - t_start)

data = {}
data['fc_feats'] = np.stack(fc_batch)
data['att_feats'] = np.stack(att_batch)
data['labels'] = label_batch
data['gts'] = gts
data['masks'] = mask_batch
data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
data['masks'] = mask_batch
data['bounds'] = {'it_pos_now': self.iterators[split],
'it_max': len(self.split_ix[split]),
'wrapped': wrapped}
data['infos'] = infos

return data
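# Example usage of get_batch (a minimal sketch; `opt` must provide the
# fields read in __init__, e.g. input_json, input_fc_dir, input_att_dir,
# input_label_h5, batch_size, seq_per_img, train_only):
#   loader = DataLoader(opt)
#   data = loader.get_batch('train')
#   fc_feats, labels, masks = data['fc_feats'], data['labels'], data['masks']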

# It's not coherent to make DataLoader a subclass of Dataset, but essentially, we only need to implement the following to functions,
# so that the torch.utils.data.DataLoader can load the data according the index.
# However, it's minimum change to switch to pytorch data loading.
# It's not coherent to make DataLoader a subclass of Dataset,
# but essentially, we only need to implement the following two functions,
# so that the torch.utils.data.DataLoader can load the data according to
# the index. However, it's a minimal change to switch to pytorch data loading.
def __getitem__(self, index):
"""This function returns a tuple that is further passed to collate_fn
"""
ix = index #self.split_ix[index]
return get_npy_data(ix, \
os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy'),
os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'),
self.use_att
)
ix = index # self.split_ix[index]
return self.get_data(ix)

def __len__(self):
return len(self.info['images'])


class ArraySampler(data.sampler.SubsetRandomSampler):
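"""Yield the given indices in their stored order; the parent
SubsetRandomSampler would permute them on every pass."""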
def __iter__(self):
return iter(self.indices)


class BlobFetcher():
"""Experimental class for prefetching blobs in a separate process."""
def __init__(self, split, dataloader, if_shuffle=False):
"""
db is a list of tuples containing: imcrop_name, caption, bbox_feat of gt box, imname
db is a list of tuples containing: imcrop_name,
caption, bbox_feat of gt box, imname
"""
self.split = split
self.dataloader = dataloader
@@ -199,17 +224,22 @@ def __init__(self, split, dataloader, if_shuffle=False):
def reset(self):
"""
Two cases:
1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator
2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already.
1. not hasattr(self, 'split_loader'): Resume from previous training.
Create the dataset given the saved split_ix and iterator
2. wrapped: a new epoch, the split_ix and iterator have been updated in
the get_minibatch_inds already.
"""
# batch_size is 0, the merge is done in DataLoader class
self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
batch_size=1,
sampler=self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:],
shuffle=False,
pin_memory=True,
num_workers=multiprocessing.cpu_count(),
collate_fn=lambda x: x[0]))
sampler = ArraySampler(
self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:])
self.split_loader = iter(
data.DataLoader(dataset=self.dataloader,
batch_size=1,
sampler=sampler,
shuffle=False,
pin_memory=True,
num_workers=multiprocessing.cpu_count(),
collate_fn=lambda x: x[0]))

def _get_next_minibatch_inds(self):
max_index = len(self.dataloader.split_ix[self.split])
@@ -227,7 +257,7 @@ def _get_next_minibatch_inds(self):
self.dataloader.iterators[self.split] = ri_next

return ix, wrapped

def get(self):
if not hasattr(self, 'split_loader'):
self.reset()
@@ -236,7 +266,6 @@ def get(self):
tmp = self.split_loader.next()
if wrapped:
self.reset()

assert tmp[2] == ix, "ix not equal"

return tmp + [wrapped]
return tmp + [wrapped]
8 changes: 4 additions & 4 deletions dataloaderraw.py
Original file line number Diff line number Diff line change
@@ -108,9 +108,10 @@ def get_batch(self, split, batch_size=None):
img = np.concatenate((img, img, img), axis=2)

img = img.astype('float32')/255.0
img = torch.from_numpy(img.transpose([2,0,1])).cuda()
img = Variable(preprocess(img), volatile=True)
tmp_fc, tmp_att = self.my_resnet(img)
img = torch.from_numpy(img.transpose([2, 0, 1])).cuda()
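# In PyTorch 0.4+, `volatile=True` is removed; torch.no_grad() disables
# autograd for this forward pass instead (Variable is now a no-op wrapper).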
with torch.no_grad():
img = Variable(preprocess(img))
tmp_fc, tmp_att = self.my_resnet(img)

fc_batch[i] = tmp_fc.data.cpu().float().numpy()
att_batch[i] = tmp_att.data.cpu().float().numpy()
@@ -136,4 +137,3 @@ def get_vocab_size(self):

def get_vocab(self):
return self.ix_to_word
