# captionGAN

Source code for the paper "Speaking the Same Language: Matching Machine to Human Captions by Adversarial Training"

This is built on *Python + numpy + Theano*. It is a large codebase implementing the captioning frameworks used in the following papers:

Image captioning:

1. "Speaking the Same Language: Matching Machine to Human Captions by Adversarial Training" (https://arxiv.org/abs/1703.10476)
2. "Paying Attention to Descriptions Generated by Image Captioning Models" (https://arxiv.org/abs/1704.07434)
3. "Exploiting scene context for image captioning" (https://dl.acm.org/citation.cfm?id=2983571)

Video captioning:

4. "Frame- and segment-level features and candidate pool evaluation for video caption generation" (https://arxiv.org/abs/1608.04959)
5. "Video captioning with recurrent networks based on frame- and video-level features and visual content classification" (https://arxiv.org/abs/1512.02949)

# Instructions on using the code

1. Make sure you have Theano installed and working. As a quick check, `import theano` should run without errors in a Python shell.
2. The code expects the data files to be in the `data/<dataset_name>` directory. It needs a `.json` file containing all the training/validation/test samples and a `.mat` feature file containing the CNN features for each sample. Actual images are only needed for visualizing results, not during training. (See the sanity-check sketch after this list.)
3. The data and some pre-trained models can be downloaded from the links below:
   * data: https://drive.google.com/open?id=0B76QzqVJdOJ5VjlaR294SVV6Z00
   * pre-trained: https://drive.google.com/open?id=0B76QzqVJdOJ5TV9FMjhpVmlsTFE
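
As a quick sanity check of items 1 and 2, something like the following should run cleanly. This is a minimal sketch in the codebase's Python 2; `dataset.json` follows the repo's convention, but the `vgg_feats.mat` filename is a placeholder, not necessarily the actual feature file name:

```python
# Minimal environment/data sanity check (Python 2, matching this codebase).
import json
import scipy.io
import theano  # should import without errors if Theano is configured

# Repo convention: data/<dataset_name>/dataset.json holds all samples.
db = json.load(open('data/coco/dataset.json'))
print 'images in dataset: %d' % len(db['images'])

# The .mat feature file name below is a placeholder -- use the file you downloaded.
feats = scipy.io.loadmat('data/coco/vgg_feats.mat')
print 'keys in feature file:', feats.keys()
```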

# Example Usage

1. Training the adversarial model:

        THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda3' python train_adversarial_caption_gen_v2.py --maxlen 20 -o cvCoco/advDummy --fappend r-dep3-frc80-resnet-1samp-pretrainBOTH-miniBatchDiscr-GUMBHard0p5-smooth3-noMle-featMatchPsentEmbMatch-randInp50d --batch_size 10 --eval_period 0.5 --max_epochs 50 --feature_file fasterRcnn_clasDetFEat80.npy --eval_feature aux_inp --aux_inp_file resnet150_2048-mean.npy -ld 1e-5 -cb 50 --word_encoding_size 512 --sent_encoding_size 400 --solver rmsprop --train_evaluator_only 0 --use_gumbel_mse 1 -lg 1e-6 --eval_model lstm_eval --eval_init trainedModels/advers/evaluators/advmodel_checkpoint_coco_wks-12-46_r-reg-res150mean-5samp-lstmevalonly_318_94.22_EVOnly.p --disk_feature 0 --metrics_to_track meteor cider len lcldiv_1 lcldiv_2 --gumbel_temp_init 0.5 --use_gumbel_hard 1 --hidden_depth 3 --en_residual_conn 1 --n_gen_samples 5 --merge_dim 50 --softmax_smooth_factor 3.0 --use_mle_train 0 --rev_eval 1 --gen_input_noise 1 --gen_feature_matching 1 --continue_training trainedModels/coco/mpi/model_checkpoint_coco_wks-12-46_r-dep3-frc80-resnet150mean_per9.32.p
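
   A rough reading of the key flags (inferred from their names, not documented in this commit): `--use_gumbel_hard 1` with `--gumbel_temp_init 0.5` enables hard Gumbel-softmax sampling of words, `--gen_feature_matching 1` adds a feature-matching objective for the generator, and `--continue_training` warm-starts the generator from an MLE pre-trained checkpoint (see step 3 below).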

2. Generating captions:

        THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda1' python predict_on_images.py cvCoco/advDummy/advmodel_checkpoint_coco_wks-12-46_r-dep3-frc80-resnet-5samp-pretrainBOTH-miniBatchDiscr-GUMBHard0p5-smooth3-noMle-featMatch-randInp50d_55999_15.20_genacc.p --aux_inp_file data/coco/resnet150_2048-mean.npy -f data/coco/fasterRcnn_clasDetFEat80.npy -i imgLists/imgListCOCO_MiniTestSet_ranzato.txt --fname_append ranzatotest_MLE-20Wrd-Smth3-randInpFeatMatch-ResnetMean-56k-beamsearch5 --softmax_smooth_factor 3.0 --labels data/coco/labels.txt --greedy 0 --computelogprob 1 --dobeamsearch 1 -b 5
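
   Reading the flags (again inferred from their names): `--dobeamsearch 1` with `-b 5` decodes with beam search of width 5 (`--greedy 0` disables greedy decoding), and `--computelogprob 1` additionally reports the log-probability of each generated caption.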

   An example image list file is available here: https://drive.google.com/open?id=0B76QzqVJdOJ5NUtEMkx4ZzNKRWM

3. Pre-training the caption generator:

        THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda0' python driver_theano.py -d coco -l 1e-4 --maxlen 20 --decay_rate 0.999 --grad_clip 10.0 --image_encoding_size 512 --word_encoding_size 512 --hidden_size 512 -o cvCoco/salLclExpts --fappend r-dep3-frc80-resnet150mean --worker_status_output_directory statusCoco/c1 --write_checkpoint_ppl_threshold 14 --regc 2.66e-07 --batch_size 256 --eval_period 0.5 --max_epochs 60 --eval_batch_size 256 --aux_inp_file resnet150_2048-14-14.npzl --feature_file fasterRcnn_clasDetFEat80.npy --data_file dataset.json --sample_by_len 1 --lr_decay_st_epoch 1 --lr_decay 0.99 --disk_feature 2 --hidden_depth 3 --en_residual_conn 1 --poolmethod "none mean"
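
   The checkpoint written by this pre-training run is presumably what step 1 consumes via `--continue_training` (note the matching `model_checkpoint_coco_..._per9.32.p` path in the training command above).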

Some of the code and structure is based on the original neuraltalk code released by Andrej Karpathy: https://github.com/karpathy/neuraltalk

---

"""Committee evaluation script.

Runs several trained captioning models (the committee 'members') on the test
split, has every member score every member's candidate captions, and computes
oracle BLEU scores for the candidate pool.
"""
import argparse
import json
import os
import random
import scipy.io
import codecs
import math
import numpy as np
import cPickle as pickle
from collections import defaultdict
from nltk.tokenize import word_tokenize
from nltk.align.bleu import BLEU
from imagernn.data_provider import getDataProvider
from imagernn.imagernn_utils import decodeGenerator, eval_split, eval_split_theano


# UTILS needed for BLEU score evaluation
def BLEUscore(candidate, references, weights):
    p_ns = [BLEU.modified_precision(candidate, references, i) for i, _ in enumerate(weights, start=1)]
    if all([x > 0 for x in p_ns]):
        s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns))
        bp = BLEU.brevity_penalty(candidate, references)
        return bp * math.exp(s)
    else:
        # some n-gram precision is zero, so the geometric mean is undefined; return 0
        return 0


def evalCandidate(candidate, references):
    """
    candidate is a single list of words, references is a list of lists of words
    written by humans.
    """
    b1 = BLEUscore(candidate, references, [1.0])
    b2 = BLEUscore(candidate, references, [0.5, 0.5])
    b3 = BLEUscore(candidate, references, [1/3.0, 1/3.0, 1/3.0])
    return [b1, b2, b3]
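
# Example usage (hypothetical tokenized data):
#   cand = ['a', 'dog', 'runs', 'on', 'the', 'beach']
#   refs = [['a', 'dog', 'running', 'on', 'a', 'beach'],
#           ['the', 'dog', 'runs', 'along', 'the', 'shore']]
#   b1, b2, b3 = evalCandidate(cand, refs)  # BLEU-1, BLEU-2, BLEU-3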


def get_bleu_scores(cands, refs):
    # Dump candidates and the 5 reference sets into eval/, run the
    # multi-bleu.perl script there, and parse the BLEU scores it prints.
    open('eval/output', 'w').write('\n'.join(cands))
    for q in xrange(5):
        open('eval/reference' + `q`, 'w').write('\n'.join([x[q] for x in refs]))
    owd = os.getcwd()
    os.chdir('eval')
    os.system('./multi-bleu.perl reference < output > scr')
    scr_txt = open('scr', 'r').read()
    bleus = map(float, scr_txt.split('=')[1].split('(')[0].split('/'))
    os.chdir(owd)
    return bleus


def eval_bleu_all_cand(params, com_dataset):
    # Oracle evaluation: BLEU-1/2/3 of every candidate sentence of every test
    # image against the human references.
    n_sent = com_dataset['n_sent']
    n_imgs = len(com_dataset['images'])
    bleu_array = np.zeros((3, n_imgs * n_sent))

    # Also load one of the result structures as the template
    res_struct = json.load(open(com_dataset['members_results'][0], 'r'))

    sid = 0
    for i in xrange(n_imgs):
        img = com_dataset['images'][i]
        refs = [r.values()[0] for r in res_struct['imgblobs'][i]['references']]

        for sent in img['sentences']:
            bleus = evalCandidate(sent['raw'], refs)
            bleu_array[:, sid] = bleus
            sid += 1
        if (i % 500) == 0:
            print('At %d\r' % i)

    return bleu_array


def evaluate_decision(params, com_dataset, eval_array):
    n_memb = com_dataset['n_memb']
    n_sent = com_dataset['n_sent']
    n_imgs = len(com_dataset['images'])

    all_references = []
    all_candidates = []

    # Also load one of the result structures as the template
    res_struct = json.load(open(com_dataset['members_results'][0], 'r'))

    # Total committee score for each candidate, summed over the members
    scr = eval_array.sum(axis=0)

    for i in xrange(n_imgs):
        img = com_dataset['images'][i]

        # Pick the candidate the committee scores highest for this image
        curr_scr = scr[i * n_sent: (i + 1) * n_sent]
        best = np.argmax(curr_scr)

        res_struct['imgblobs'][i]['candidate']['logprob'] = curr_scr[best]
        res_struct['imgblobs'][i]['candidate']['text'] = img['sentences'][best]['raw']

        refs = [r.values()[0] for r in res_struct['imgblobs'][i]['references']]

        all_references.append(refs)
        all_candidates.append(img['sentences'][best]['raw'])

    print 'writing intermediate files into eval/'
    # invoke the perl script to get BLEU scores
    print 'invoking eval/multi-bleu.perl script...'
    bleus = get_bleu_scores(all_candidates, all_references)
    res_struct['FinalBleu'] = bleus
    print bleus

    print 'saving result struct to %s' % (params['result_struct_filename'], )
    json.dump(res_struct, open(params['result_struct_filename'], 'w'))


def hold_comittee_discussion(params, com_dataset):
    # Each committee member (a trained model) scores the candidate captions
    # generated by all the members, filling an (n_memb x n_imgs*n_sent) array
    # of log-probability scores.
    n_memb = com_dataset['n_memb']
    n_sent = com_dataset['n_sent']
    n_imgs = len(com_dataset['images'])

    eval_array = np.zeros((n_memb, n_imgs * n_sent))
    model_id = 0
    for mod in com_dataset['members_model']:
        checkpoint = pickle.load(open(mod, 'rb'))
        checkpoint_params = checkpoint['params']
        dataset = checkpoint_params['dataset']
        model_npy = checkpoint['model']

        checkpoint_params['use_theano'] = 1

        if 'image_feat_size' not in checkpoint_params:
            checkpoint_params['image_feat_size'] = 4096

        checkpoint_params['data_file'] = params['jsonFname'].rsplit('/')[-1]
        dp = getDataProvider(checkpoint_params)

        ixtoword = checkpoint['ixtoword']

        blob = {}  # output blob which we will dump to JSON for visualizing the results
        blob['params'] = params
        blob['checkpoint_params'] = checkpoint_params
        blob['imgblobs'] = []

        # iterate over all images in test set and predict sentences
        BatchGenerator = decodeGenerator(checkpoint_params)

        BatchGenerator.build_eval_other_sent(BatchGenerator.model_th, checkpoint_params, model_npy)

        eval_batch_size = params.get('eval_batch_size', 100)
        eval_max_images = params.get('eval_max_images', -1)
        wordtoix = checkpoint['wordtoix']

        split = 'test'
        print 'evaluating %s performance in batches of %d' % (split, eval_batch_size)
        gen_fprop = BatchGenerator.f_eval_other
        blob['params'] = params
        c_id = 0
        for batch in dp.iterImageSentencePairBatch(split=split, max_batch_size=eval_batch_size, max_images=eval_max_images):
            xWd, xId, maskd, lenS = dp.prepare_data(batch, wordtoix)
            eval_array[model_id, c_id:c_id + xWd.shape[1]] = gen_fprop(xWd, xId, maskd)
            c_id += xWd.shape[1]

        model_id += 1

    # Calculate oracle scores
    bleu_array = eval_bleu_all_cand(params, com_dataset)
    eval_results = {}
    eval_results['logProb_feat'] = eval_array
    eval_results['OracleBleu'] = bleu_array

    # Save the mutual evaluations
    params['comResFname'] = 'committee_evalSc_%s.json' % (params['fappend'])
    com_dataset['com_evaluation'] = params['comResFname']
    pickle.dump(eval_results, open(params['comResFname'], "wb"))
    json.dump(com_dataset, open(params['jsonFname'], 'w'))

    return eval_array


def main(params):
    dataset = 'coco'
    data_file = 'dataset.json'

    # !assumptions on folder structure
    dataset_root = os.path.join('data', dataset)

    result_list = open(params['struct_list'], 'r').read().splitlines()

    # Load all result files
    result_struct = [json.load(open(res, 'r')) for res in result_list]

    # load the dataset into memory
    dataset_path = os.path.join(dataset_root, data_file)
    print 'BasicDataProvider: reading %s' % (dataset_path, )
    dB = json.load(open(dataset_path, 'r'))

    res_idx = 0

    com_dataset = {}
    com_dataset['dataset'] = 'coco'
    com_dataset['members_results'] = result_list
    com_dataset['members_model'] = list(set([res['params']['checkpoint_path'] for res in result_struct]))
    com_dataset['images'] = []
    com_dataset['n_memb'] = len(com_dataset['members_model'])
    com_dataset['n_sent'] = len(com_dataset['members_results'])

    # Pick only test images.
    # We are doing this circus in order to reuse the data provider class to
    # form nice batches when doing evaluation. The data provider expects the
    # database files to be in the original "dataset.json" format! Hence we copy
    # all necessary fields from dataset.json and replace the reference
    # sentences with the sentences generated by our models.
    for img in dB['images']:
        if img['split'] == 'test':
            # Copy everything!
            com_dataset['images'].append(img)

            # delete reference sentences
            com_dataset['images'][-1]['sentences'] = []
            for res_st in result_struct:
                com_dataset['images'][-1]['sentences'].append({
                    'img_id': img['imgid'],
                    'raw': res_st['imgblobs'][res_idx]['candidate']['text'],
                    'sentid': res_st['params']['beam_size'],
                    'mid': com_dataset['members_model'].index(res_st['params']['checkpoint_path']),
                    'tokens': word_tokenize(res_st['imgblobs'][res_idx]['candidate']['text'])})

            res_idx += 1
            if res_idx == 5000:
                break

    print 'Done with %d! Now writing back dataset' % (res_idx)
    params['jsonFname'] = 'committee_struct_%s.json' % (params['fappend'])
    params['jsonFname'] = os.path.join(dataset_root, params['jsonFname'])
    json.dump(com_dataset, open(params['jsonFname'], 'w'))
    return com_dataset, params
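
# Shape of the committee struct built above (schema sketch; field values illustrative):
#   {'dataset': 'coco',
#    'members_results': [<result .json paths>],
#    'members_model': [<unique model checkpoint paths>],
#    'n_memb': <number of models>, 'n_sent': <candidates per image>,
#    'images': [{..., 'sentences': [{'img_id': ..., 'raw': ...,
#                                    'mid': ..., 'tokens': [...]}, ...]}, ...]}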


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('struct_list', type=str, help='the input list of result structures to form the committee from')
    parser.add_argument('--fappend', type=str, default='', help='string to append to output files')
    parser.add_argument('--result_struct_filename', type=str, default='committee_result.json', help='filename of the result struct to save')

    args = parser.parse_args()
    params = vars(args)  # convert to ordinary dict
    print 'parsed parameters:'
    print json.dumps(params, indent=2)

    com_dataset, params = main(params)
    eval_array = hold_comittee_discussion(params, com_dataset)
    #evaluate_decision(params, com_dataset, eval_array)
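
# Example invocation (the script filename here is hypothetical):
#   python committee_evaluation.py result_structs_list.txt --fappend cocotest
# where result_structs_list.txt lists one result-struct .json path per line.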