From 263e089ecc1181de993ec650ddc9e564537bea9b Mon Sep 17 00:00:00 2001 From: Jon Almazan Date: Thu, 8 Aug 2019 11:54:16 +0100 Subject: [PATCH] fix net.pca bug --- dirtorch/test_dir.py | 53 +++++++++----------------------------------- 1 file changed, 10 insertions(+), 43 deletions(-) diff --git a/dirtorch/test_dir.py b/dirtorch/test_dir.py index e01ac3f..4c986c6 100644 --- a/dirtorch/test_dir.py +++ b/dirtorch/test_dir.py @@ -19,6 +19,7 @@ import pickle as pkl import hashlib + def hash(x): m = hashlib.md5() m.update(str(x).encode('utf-8')) @@ -84,12 +85,12 @@ def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size= img_feats = [] trf_images = [] - with torch.no_grad(): # important to put it outside! + with torch.no_grad(): for inputs in tqdm.tqdm(loader, desc, total=1+(len(dataset)-1)//batch_size): imgs = inputs[0] for i in range(len(imgs)): if flip and flip.pop(0): - imgs[i] = imgs[i].flip(2) # flip this image horizontally! + imgs[i] = imgs[i].flip(2) imgs = common.variables(inputs[:1], net.iscuda)[0] desc = net(imgs) if ret_imgs: trf_images.append( tocpu(imgs.detach()) ) @@ -99,7 +100,7 @@ def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size= img_feats.append( desc.detach() ) img_feats = torch.cat(img_feats, dim=0) - if len(img_feats.shape) == 1: img_feats.unsqueeze_(0) # atleast_2d + if len(img_feats.shape) == 1: img_feats.unsqueeze_(0) if not same_size: torch.backends.cudnn.benchmark = old_benchmark @@ -197,25 +198,6 @@ def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=Non if detailed: res['APs'+'-'+mode] = apst res['mAP'+'-'+mode] = float(np.mean([e for e in apst if e>=0])) # Queries with no relevants have an AP of -1 - if 'ap' in dbg: - pdb.set_trace() - pyplot(globals()) - for query in np.argsort(aps): - subplot_grid(20, 1) - pl.imshow(query_db.get_image(query)) - qlabel = query_db.get_label(query) - pl.xlabel('#%d %s' % (query, qlabel)) - pl_noticks() - ranked = np.argsort(scores[query])[::-1] - gt = db.get_query_groundtruth(query)[ranked] - - for i,idx in enumerate(ranked): - if i+2 > 20: break - subplot_grid(20, i+2) - pl.imshow(db.get_image(idx)) - pl.xlabel('#%d %s %g' % (idx, 'OK' if label==qlabel else 'BAD', scores[query,idx])) - pl_noticks() - pdb.set_trace() except NotImplementedError: print(" AP not implemented!") @@ -236,7 +218,7 @@ def load_model( path, iscuda ): net = common.switch_model_to_cuda(net, iscuda, checkpoint) net.load_state_dict(checkpoint['state_dict']) net.preprocess = checkpoint.get('preprocess', net.preprocess) - if 'pca' in checkpoint: net.pca = checkpoint.get('pca', net.pca) + if 'pca' in checkpoint: net.pca = checkpoint.get('pca') return net @@ -263,7 +245,6 @@ def learn_whiten( dataset, net, trfs='', pooling='mean', threads=8, batch_size=1 parser.add_argument('--trfs', type=str, required=False, default='', nargs='+', help='test transforms (can be several)') parser.add_argument('--pooling', type=str, default="gem", help='pooling scheme if several trf chains') parser.add_argument('--gemp', type=int, default=3, help='GeM pooling power') - parser.add_argument('--center-bias', type=float, default=0, help='enforce some center bias') parser.add_argument('--out-json', type=str, default="", help='path to output json') parser.add_argument('--detailed', action='store_true', help='return detailed evaluation') @@ -296,26 +277,12 @@ def learn_whiten( dataset, net, trfs='', pooling='mean', threads=8, batch_size=1 net = load_model(args.checkpoint, args.iscuda) - if args.center_bias: - assert hasattr(net,'center_bias') - net.center_bias = args.center_bias - if hasattr(net, 'module') and hasattr(net.module,'center_bias'): - net.module.center_bias = args.center_bias - - if args.whiten and not hasattr(net, 'pca'): - # Learn PCA if necessary - if os.path.exists(args.whiten): - with open(args.whiten, 'rb') as f: - net.pca = pkl.load(f) - else: - pca_path = '_'.join([args.checkpoint, args.whiten, args.pooling, hash(args.trfs), 'pca.pkl']) - db = datasets.create(args.whiten) - print('Dataset for learning the PCA with whitening:', db) - net.pca = learn_whiten(db, net, pooling=args.pooling, trfs=args.trfs, threads=args.threads) - with open(pca_path, 'wb') as f: - pkl.dump(net.pca, f) - + if args.whiten: + net.pca = net.pca[args.whiten] args.whiten = {'whitenp': args.whitenp, 'whitenv': args.whitenv, 'whitenm': args.whitenm} + else: + net.pca = None + args.whiten = None # Evaluate res = eval_model(dataset, net, args.trfs, pooling=args.pooling, gemp=args.gemp, detailed=args.detailed,