Skip to content

Commit

Permalink
fix net.pca bug
Browse files Browse the repository at this point in the history
  • Loading branch information
almazan committed Aug 8, 2019
1 parent 21cb3f8 commit 263e089
Showing 1 changed file with 10 additions and 43 deletions.
53 changes: 10 additions & 43 deletions dirtorch/test_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pickle as pkl
import hashlib


def hash(x):
m = hashlib.md5()
m.update(str(x).encode('utf-8'))
Expand Down Expand Up @@ -84,12 +85,12 @@ def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size=

img_feats = []
trf_images = []
with torch.no_grad(): # important to put it outside!
with torch.no_grad():
for inputs in tqdm.tqdm(loader, desc, total=1+(len(dataset)-1)//batch_size):
imgs = inputs[0]
for i in range(len(imgs)):
if flip and flip.pop(0):
imgs[i] = imgs[i].flip(2) # flip this image horizontally!
imgs[i] = imgs[i].flip(2)
imgs = common.variables(inputs[:1], net.iscuda)[0]
desc = net(imgs)
if ret_imgs: trf_images.append( tocpu(imgs.detach()) )
Expand All @@ -99,7 +100,7 @@ def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size=
img_feats.append( desc.detach() )

img_feats = torch.cat(img_feats, dim=0)
if len(img_feats.shape) == 1: img_feats.unsqueeze_(0) # atleast_2d
if len(img_feats.shape) == 1: img_feats.unsqueeze_(0)

if not same_size:
torch.backends.cudnn.benchmark = old_benchmark
Expand Down Expand Up @@ -197,25 +198,6 @@ def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=Non
if detailed: res['APs'+'-'+mode] = apst
res['mAP'+'-'+mode] = float(np.mean([e for e in apst if e>=0])) # Queries with no relevants have an AP of -1

if 'ap' in dbg:
pdb.set_trace()
pyplot(globals())
for query in np.argsort(aps):
subplot_grid(20, 1)
pl.imshow(query_db.get_image(query))
qlabel = query_db.get_label(query)
pl.xlabel('#%d %s' % (query, qlabel))
pl_noticks()
ranked = np.argsort(scores[query])[::-1]
gt = db.get_query_groundtruth(query)[ranked]

for i,idx in enumerate(ranked):
if i+2 > 20: break
subplot_grid(20, i+2)
pl.imshow(db.get_image(idx))
pl.xlabel('#%d %s %g' % (idx, 'OK' if label==qlabel else 'BAD', scores[query,idx]))
pl_noticks()
pdb.set_trace()
except NotImplementedError:
print(" AP not implemented!")

Expand All @@ -236,7 +218,7 @@ def load_model( path, iscuda ):
net = common.switch_model_to_cuda(net, iscuda, checkpoint)
net.load_state_dict(checkpoint['state_dict'])
net.preprocess = checkpoint.get('preprocess', net.preprocess)
if 'pca' in checkpoint: net.pca = checkpoint.get('pca', net.pca)
if 'pca' in checkpoint: net.pca = checkpoint.get('pca')
return net


Expand All @@ -263,7 +245,6 @@ def learn_whiten( dataset, net, trfs='', pooling='mean', threads=8, batch_size=1
parser.add_argument('--trfs', type=str, required=False, default='', nargs='+', help='test transforms (can be several)')
parser.add_argument('--pooling', type=str, default="gem", help='pooling scheme if several trf chains')
parser.add_argument('--gemp', type=int, default=3, help='GeM pooling power')
parser.add_argument('--center-bias', type=float, default=0, help='enforce some center bias')

parser.add_argument('--out-json', type=str, default="", help='path to output json')
parser.add_argument('--detailed', action='store_true', help='return detailed evaluation')
Expand Down Expand Up @@ -296,26 +277,12 @@ def learn_whiten( dataset, net, trfs='', pooling='mean', threads=8, batch_size=1

net = load_model(args.checkpoint, args.iscuda)

if args.center_bias:
assert hasattr(net,'center_bias')
net.center_bias = args.center_bias
if hasattr(net, 'module') and hasattr(net.module,'center_bias'):
net.module.center_bias = args.center_bias

if args.whiten and not hasattr(net, 'pca'):
# Learn PCA if necessary
if os.path.exists(args.whiten):
with open(args.whiten, 'rb') as f:
net.pca = pkl.load(f)
else:
pca_path = '_'.join([args.checkpoint, args.whiten, args.pooling, hash(args.trfs), 'pca.pkl'])
db = datasets.create(args.whiten)
print('Dataset for learning the PCA with whitening:', db)
net.pca = learn_whiten(db, net, pooling=args.pooling, trfs=args.trfs, threads=args.threads)
with open(pca_path, 'wb') as f:
pkl.dump(net.pca, f)

if args.whiten:
net.pca = net.pca[args.whiten]
args.whiten = {'whitenp': args.whitenp, 'whitenv': args.whitenv, 'whitenm': args.whitenm}
else:
net.pca = None
args.whiten = None

# Evaluate
res = eval_model(dataset, net, args.trfs, pooling=args.pooling, gemp=args.gemp, detailed=args.detailed,
Expand Down

0 comments on commit 263e089

Please sign in to comment.