diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a70e808
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..d2af743
--- /dev/null
+++ b/README.md
@@ -0,0 +1,184 @@
+# Deep Image Retrieval
+
+This repository contains the models and the evaluation scripts (in Python3 and PyTorch 1.0) of the papers:
+
+**[1] End-to-end Learning of Deep Visual Representations for Image Retrieval**
+Albert Gordo, Jon Almazan, Jerome Revaud, Diane Larlus, IJCV 2017 [\[PDF\]](https://arxiv.org/abs/1610.07940)
+
+**[2] Learning with Average Precision: Training Image Retrieval with a Listwise Loss**
+Jerome Revaud, Rafael S. Rezende, Cesar de Souza, Jon Almazan, arXiv 2019 [\[PDF\]](https://es.naverlabs.com/jon-almazan/deep-image-retrieval/files/4381/paper.pdf)
+
+
+Both papers tackle the problem of image retrieval and explore different ways to learn deep visual representations for this task. In both cases, a CNN extracts a feature map that is aggregated into a compact, fixed-length representation by a global-aggregation layer*. This representation is then projected with a fully-connected (FC) layer and L2-normalized, so that images can be efficiently compared with the dot product.
+
+
+![dir_network](https://es.naverlabs.com/storage/user/2353/files/f943154c-7736-11e9-83fd-bd0ab10db282)
+
+All components in this network, including the aggregation layer, are differentiable, which makes it end-to-end trainable for the end task. In [1], a Siamese architecture that combines three streams with a triplet loss was proposed to train this network. In [2], this work was extended by replacing the triplet loss with a new loss that directly optimizes for Average Precision.
+
+![Losses](https://es.naverlabs.com/storage/user/2353/files/f50571c4-82f2-11e9-8cf4-228334f7c335)
+
+\* Originally, [1] used R-MAC pooling [3] as the global-aggregation layer. However, due to its efficiency and better performance, we have replaced the R-MAC pooling layer with the Generalized-mean pooling layer (GeM) proposed in [4]. You can find the original Caffe implementation of [1] following [this link](https://europe.naverlabs.com/Research/Computer-Vision/Learning-Visual-Representations/Deep-Image-Retrieval/).
+
+## Pre-requisites
+
+In order to run this toolbox you will need:
+
+- Python3 (tested with Python 3.7.3)
+- PyTorch (tested with version 1.0.1)
+- The following packages: numpy, matplotlib, tqdm, scikit-learn
+
+With conda you can run the following commands:
+
+```
+conda install numpy matplotlib tqdm scikit-learn
+conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
+```
+
+## Installation
+
+```
+# Download the code
+git clone git@es.naverlabs.com:jon-almazan/deep-image-retrieval.git
+
+# Create env variables
+cd deep-image-retrieval
+export DIR_ROOT=$PWD
+export DB_ROOT=/PATH/TO/YOUR/DATASETS
+# for example: export DB_ROOT=$PWD/dirtorch/data/datasets
+```
+
+
+## Evaluation
+
+
+### Pre-trained models
+
+The table below contains the pre-trained models that we provide with this library, together with their mAP performance on some of the most well-known image retrieval benchmarks: [Oxford5K](http://www.robots.ox.ac.uk/~vgg/data/oxbuildings/), [Paris6K](http://www.robots.ox.ac.uk/~vgg/data/parisbuildings/), and their Revisited versions ([ROxford5K and RParis6K](https://github.com/filipradenovic/revisitop)).
+
+
+| Model | Oxford5K | Paris6K | ROxford5K (med/hard) | RParis6K (med/hard) |
+|--- |:-:|:-:|:-:|:-:|
+| [Resnet101-TL-MAC]() |  | 91.0 | 63.6 / 37.1 | 76.7 / 55.7 |
+| [Resnet101-TL-GeM]() | 85.5 | 93.4 | 64.8 / 41.6 | 78.9 / 59.4 |
+| [Resnet50-AP-GeM]() | 87.9 | 91.9 | 65.8 / 41.7 | 77.6 / 57.3 |
+| [Resnet101-AP-GeM](https://bit.ly/2LGLbnj) | 89.3 | 93.0 | 67.4 / 42.8 | 80.4 / 61.0 |
+| [Resnet101-AP-GeM-LM18]()** | 88.4 | 93.0 | 66.5 / 43.1 | 80.2 / 60.4 |
+
+
+The name of each model encodes the backbone architecture of the network and the loss that has been used to train it (TL for triplet loss and AP for Average Precision loss). All models use **Generalized-mean pooling (GeM)** [4] as the global pooling mechanism, except for the model in the first row, which uses MAC [3] \(i.e. max-pooling), and have been trained on the **Landmarks-clean** [1] dataset (the clean version of the [Landmarks dataset](http://sites.skoltech.ru/compvision/projects/neuralcodes/)) by directly **fine-tuning from ImageNet**. These numbers have been obtained using a **single resolution** and applying **whitening** to the output features (which has also been learned on Landmarks-clean). For a detailed explanation of all the hyper-parameters, see [1] and [2] for the triplet loss and AP loss models, respectively.
+
+** For the sake of completeness, we have added an extra model, `Resnet101-AP-GeM-LM18`, which has been trained on the [Google-Landmarks Dataset](https://www.kaggle.com/google/google-landmarks-dataset), a large dataset consisting of more than 1M images and 15K classes.
+
+### Reproducing the results
+
+The script `test_dir.py` can be used to evaluate the pre-trained models provided and to reproduce the results above:
+
+```
+python -m dirtorch.test_dir --dataset DATASET --checkpoint PATH_TO_MODEL
+                [--whiten DATASET] [--whitenp POWER] [--aqe ALPHA-QEXP]
+                [--trfs TRANSFORMS] [--gpu ID] [...]
+```
+
+- `--dataset`: selects the dataset (e.g.: Oxford5K, Paris6K, ROxford5K, RParis6K) [**required**]
+- `--checkpoint`: path to the model weights [**required**]
+- `--whiten`: applies whitening to the output features [default: 'Landmarks_clean']
+- `--whitenp`: whitening power [default: 0.25]
+- `--aqe`: alpha-query expansion parameters [default: None]
+- `--trfs`: input image transformations (can be used to apply multi-scale) [default: None]
+- `--gpu`: selects the GPU ID (-1 selects the CPU)
+
+For example, to reproduce the results of the Resnet101-AP-GeM model on the RParis6K dataset, run:
+
+```
+cd $DIR_ROOT
+export DB_ROOT=/PATH/TO/YOUR/DATASETS
+
+mkdir -p dirtorch/data/models
+wget https://bit.ly/2LGLbnj -O model.tgz
+tar -C dirtorch/data/models -xzf model.tgz && rm model.tgz
+
+python -m dirtorch.test_dir --dataset RParis6K \
+    --checkpoint dirtorch/data/models/resnet101_APloss_gem.pt \
+    --whiten Landmarks_clean --whitenp 0.25 --gpu 0
+```
+
+And you should see the following output:
+
+```
+>> Evaluation...
+ top1 not implemented!
+ * mAP-easy = 0.911001
+ * mAP-medium = 0.80115
+ * mAP-hard = 0.604583
+```
+
+**Note:** this script integrates an automatic downloader for the Oxford5K, Paris6K, ROxford5K, and RParis6K datasets (kudos to Filip Radenovic ;)). The datasets will be saved in `$DB_ROOT`.
+
+## Feature extractor
+
+You can also use the pre-trained models to extract features from your own datasets or collections of images.
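+Since the output descriptors are L2-normalized, comparing two images boils down to a dot product. As a minimal illustrative sketch (the file names below are assumptions; if the dataset has no separate query set, the script writes a single `.npy` file instead), suppose you have already run the extraction script described below with `--output out/feats.npy`:
+
+```
+import numpy as np
+
+# Hypothetical output files written by the extraction script (--output out/feats.npy)
+qdescs = np.load('out/feats.qdescs.npy')   # (n_queries, D) L2-normalized query descriptors
+bdescs = np.load('out/feats.dbdescs.npy')  # (n_database, D) L2-normalized database descriptors
+
+scores = qdescs @ bdescs.T                 # dot products == cosine similarities for unit vectors
+ranks = np.argsort(-scores, axis=1)        # database indices sorted from best to worst match
+print(ranks[:, :5])                        # top-5 retrieved images for each query
+```
+
+The first step, of course, is to extract and save these descriptors for your own images.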
+For that we provide the script `extract_features.py`:
+
+```
+python -m dirtorch.extract_features --dataset DATASET --checkpoint PATH_TO_MODEL
+                --output PATH_TO_FILE [--whiten DATASET] [--whitenp POWER]
+                [--trfs TRANSFORMS] [--gpu ID] [...]
+```
+
+where `--output` is used to specify the destination where the features will be saved. The rest of the parameters are the same as seen above.
+
+The library provides a generic dataset class (`ImageList`) that allows you to specify the list of images by providing a simple text file.
+
+```
+--dataset 'ImageList("PATH_TO_TEXTFILE" [, "IMAGES_ROOT"])'
+```
+
+Each row of the text file should contain a single path to a given image:
+
+```
+/PATH/TO/YOUR/DATASET/images/image1.jpg
+/PATH/TO/YOUR/DATASET/images/image2.jpg
+/PATH/TO/YOUR/DATASET/images/image3.jpg
+/PATH/TO/YOUR/DATASET/images/image4.jpg
+/PATH/TO/YOUR/DATASET/images/image5.jpg
+```
+
+Alternatively, you can use relative paths and pass `IMAGES_ROOT` to specify the root folder.
+
+
+## Citations
+
+Please consider citing the following papers in your publications if this repository helps your research.
+
+```
+@article{GARL17,
+  title = {End-to-end Learning of Deep Visual Representations for Image Retrieval},
+  author = {Gordo, A. and Almazan, J. and Revaud, J. and Larlus, D.},
+  journal = {IJCV},
+  year = {2017}
+}
+
+@inproceedings{RARS19,
+  title = {Learning with Average Precision: Training Image Retrieval with a Listwise Loss},
+  author = {Revaud, J. and Almazan, J. and Rezende, R.S. and de Souza, C.R.},
+  booktitle = {ArXiv},
+  year = {2019}
+}
+```
+
+## Contributors
+
+This library has been developed by Jerome Revaud, Rafael de Rezende, Cesar de Souza, Diane Larlus, and Jon Almazan at **[Naver Labs Europe](https://europe.naverlabs.com)**.
+
+
+**Special thanks** to Filip Radenovic. In this library we use the ROxford5K and RParis6K downloader from his awesome **[CNN-imageretrieval repository](https://github.com/filipradenovic/cnnimageretrieval-pytorch)**. Consider checking it out if you want to train your own models for image retrieval!
+
+## References
+
+[1] Gordo, A., Almazan, J., Revaud, J., Larlus, D., [End-to-end Learning of Deep Visual Representations for Image Retrieval](https://arxiv.org/abs/1610.07940). IJCV 2017
+
+[2] Revaud, J., de Souza, C., Rezende, R.S., Almazan, J., [Learning with Average Precision: Training Image Retrieval with a Listwise Loss](). ArXiv 2019
+
+[3] Tolias, G., Sicre, R., Jegou, H., [Particular object retrieval with integral max-pooling of CNN activations](https://arxiv.org/abs/1511.05879). ICLR 2016
+
+[4] Radenovic, F., Tolias, G., Chum, O., [Fine-tuning CNN Image Retrieval with No Human Annotation](https://arxiv.org/pdf/1711.02512).
TPAMI 2018 diff --git a/dirtorch/__init__.py b/dirtorch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dirtorch/datasets/__init__.py b/dirtorch/datasets/__init__.py new file mode 100644 index 0000000..eb4fe06 --- /dev/null +++ b/dirtorch/datasets/__init__.py @@ -0,0 +1,17 @@ +try: from .oxford import * +except ImportError: pass +try: from .paris import * +except ImportError: pass +try: from .distractors import * +except ImportError: pass +try: from .landmarks import Landmarks_clean, Landmarks_clean_val, Landmarks_lite +except ImportError: pass +try: from .landmarks18 import * +except ImportError: pass + +# create a dataset from a string +from .create import * +create = DatasetCreator(globals()) + +from .dataset import split, deploy, deploy_and_split +from .generic import * diff --git a/dirtorch/datasets/__main__.py b/dirtorch/datasets/__main__.py new file mode 100644 index 0000000..29b52e1 --- /dev/null +++ b/dirtorch/datasets/__main__.py @@ -0,0 +1,83 @@ +import os +import sys +import pdb +from nltools.gutils.pyplot import * + + +def viz_dataset(db, nr=6, nc=6): + ''' a convenient way to vizualize the content of a dataset. + If there are queries, it will show the ground-truth for each query. + ''' + pyplot(globals()) + + try: + query_db = db.get_query_db() + + LABEL = ['null', 'pos', 'neg'] + + for query in range(query_db.nimg): + figure("Query") + subplot_grid(20, 1) + pl.imshow(query_db.get_image(query)) + pl.xlabel('Query #%d' % query) + + pl_noticks() + gt = db.get_query_groundtruth(query) + ranked = (-gt).argsort() + + for i,idx in enumerate(ranked): + if i+2 > 20: break + subplot_grid(20, i+2) + pl.imshow(db.get_image(idx)) + label = gt[idx] + pl.xlabel('#%d %s' % (idx, LABEL[label])) + pl_noticks() + pdb.set_trace() + + except NotImplementedError: + import numpy as np + pl.ion() + + R = 1 + nr = nr // R + + def show_img(r, c, i): + i = np.random.randint(len(db)) + + pl.subplot(R*nr,nc,(R*r+0)*nc+c+1) + img = db.get_image(i) + print(i, db.get_key(i), "%d x %d" % img.size) + pl.imshow(img) + pl.xticks(()) + pl.yticks(()) + if db.has_label(): + pl.xlabel(db.get_label(i)) + + pl.figure() + pl.subplots_adjust(0,0,1,1,0.02,) + n = 0 + while True: + pl.clf() + for r in range(nr): + for c in range(nc): + show_img(r,c,n) + n += 1 + pdb.set_trace() + + + +if __name__ == '__main__': + from .__init__ import create + + args = sys.argv[1:] + if not args: + print("Error: Provide a db_cmd to this script"); + exit() + + db = args.pop(0) + print("Instanciating dataset '%s'..." % db) + + db = create(db) + print(db) + + viz_dataset(db) diff --git a/dirtorch/datasets/create.py b/dirtorch/datasets/create.py new file mode 100644 index 0000000..a8f028f --- /dev/null +++ b/dirtorch/datasets/create.py @@ -0,0 +1,30 @@ +from .dataset import split, deploy, deploy_and_split +from .generic import * + + +class DatasetCreator: + ''' Create a dataset from a string. + + dataset_cmd (str): + Command to execute. + ex: "ImageList('path/to/list.txt')" + + Returns: + instanciated dataset. 
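+
+    Example (mirrors how dirtorch.datasets instantiates it):
+        create = DatasetCreator(globals())
+        db = create("ImageList('path/to/list.txt')")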
+ ''' + def __init__(self, globs): + for k, v in globs.items(): + globals()[k] = v + + def __call__(self, dataset_cmd ): + if '(' not in dataset_cmd: + dataset_cmd += "()" + + try: + return eval(dataset_cmd) + except NameError: + import sys, inspect + dbs = [name for name,obj in globals().items() if name[0]!='_' and name not in ('DatasetCreator','defaultdict') and inspect.isclass(obj)] + print("Error: unknown dataset %s\nAvailable datasets: %s" % (dataset_cmd.replace('()',''), ', '.join(sorted(dbs))), file=sys.stderr) + sys.exit(1) + diff --git a/dirtorch/datasets/dataset.py b/dirtorch/datasets/dataset.py new file mode 100644 index 0000000..f4f0603 --- /dev/null +++ b/dirtorch/datasets/dataset.py @@ -0,0 +1,587 @@ +import os +import json +import pdb +import numpy as np +from collections import defaultdict + + +class Dataset(object): + ''' Base class for a dataset. To be overloaded. + + Contains: + - images --> get_image(i) --> image + - image labels --> get_label(i) + - list of image queries --> get_query(i) --> image + - list of query ROIs --> get_query_roi(i) + + Creation: + Use dataset.create( "..." ) to instanciate one. + db = dataset.create( "ImageList('path/to/list.txt')" ) + + Attributes: + root: image directory root + nimg: number of images == len(self) + nclass: number of classes + ''' + root = '' + img_dir = '' + nimg = 0 + nclass = 0 + ninstance = 0 + + classes = [] # all class names (len == nclass) + labels = [] # all image labels (len == nimg) + c_relevant_idx = {} # images belonging to each class (c_relevant_idx[cl_name] = [idx list]) + + def __len__(self): + return self.nimg + + def get_filename(self, img_idx, root=None): + return os.path.join(root or self.root, self.img_dir, self.get_key(img_idx)) + + def get_key(self, img_idx): + raise NotImplementedError() + + def key_to_index(self, key): + if not hasattr(self, '_key_to_index'): + self._key_to_index = {self.get_key(i):i for i in range(len(self))} + return self._key_to_index[key] + + def get_image(self, img_idx, resize=None): + from PIL import Image + img = Image.open(self.get_filename(img_idx)).convert('RGB') + if resize: + img = img.resize(resize, Image.ANTIALIAS if np.prod(resize) < np.prod(img.size) else Image.BICUBIC) + return img + + def get_image_size(self, img_idx): + return self.imsize + + def get_label(self, img_idx, toint=False): + raise NotImplementedError() + + def has_label(self): + try: self.get_label(0); return True + except NotImplementedError: return False + + def get_query_db(self): + raise NotImplementedError() + + def get_query_groundtruth(self, query_idx, what='AP'): + query_db = self.get_query_db() + assert self.nclass == query_db.nclass + if what == 'AP': + res = -np.ones(self.nimg, dtype=np.int8) # all negatives + res[self.c_relevant_idx[query_db.get_label(query_idx)]] = 1 # positives + if query_db == self: res[query_idx] = 0 # query is junk + elif what == 'label': + res = query_db.get_label(query_idx) + else: + raise ValueError("Unknown ground-truth type: %s" % what) + return res + + def eval_query_AP(self, query_idx, scores): + """ Evaluates AP for a given query. 
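+
+        query_idx (int): index into the query database.
+        scores (np.ndarray): similarity scores w.r.t. every database image, shape (nimg,).
+        Returns the AP, or -1 if the query has no relevant images.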
+ """ + from ..utils.evaluation import compute_AP + gt = self.get_query_groundtruth(query_idx, 'AP') # labels in {-1, 0, 1} + assert gt.shape == scores.shape, "scores should have shape %s" % str(gt.shape) + assert -1 <= gt.min() and gt.max() <= 1, "bad ground-truth labels" + keep = (gt != 0) # remove null labels + if sum(gt[keep]>0) == 0: return -1 # exclude queries with no relevants form the evaluation + return compute_AP(gt[keep]>0, scores[keep]) + + def eval_query_top(self, query_idx, scores, k=(1,5,10,20,50,100)): + """ Evaluates top-k for a given query. + """ + if not self.labels: raise NotImplementedError() + q_label = self.get_query_groundtruth(query_idx, 'label') + correct = np.bool8([l==q_label for l in self.labels]) + correct = correct[(-scores).argsort()] + return {k_:float(correct[:k_].any()) for k_ in k if k_ 0: + train.append( imgs.pop() ) # ensure at least 1 training sample + for i in range(int(0.9999+val_prop*nn)): + if imgs: val.append( imgs.pop() ) + for i in range(int(0.9999+test_prop*nn)): + if imgs: test.append( imgs.pop() ) + if imgs: train += imgs + + train.sort() + val.sort() + test.sort() + + elif method == 'hash': + val_prop2 = train_prop + val_prop + for i in range(len(dataset)): + fname = dataset.get_key(i) + + # compute file hash + h = (hash(fname)%100)/100.0 + if h < train_prop: + train.append( i ) + elif h < val_prop2: + val.append( i ) + else: + test.append( i ) + else: + raise ValueError("bad split method "+method) + + train = SubDataset(dataset, train) + val = SubDataset(dataset, val) + test = SubDataset(dataset, test) + + if val_prop == 0: + return train, test + else: + return train, val, test + + +class SubDataset(Dataset): + ''' Contains a sub-part of another dataset. + ''' + def __init__(self, dataset, indices): + self.root = dataset.root + self.img_dir = dataset.img_dir + self.dataset = dataset + self.indices = indices + + self.nimg = len(self.indices) + self.nclass = self.dataset.nclass + + def get_key(self, i): + return self.dataset.get_key(self.indices[i]) + + def get_label(self, i, **kw): + return self.dataset.get_label(self.indices[i],**kw) + + def get_bbox(self, i, **kw): + if hasattr(self.dataset,'get_bbox'): + return self.dataset.get_bbox(self.indices[i],**kw) + else: + raise NotImplementedError() + + def __repr__(self): + res = 'SubDataset(%s)\n' % self.dataset.__class__.__name__ + res += ' %d/%d images, %d classes\n' % (len(self),len(self.dataset),self.nclass) + res += ' root: %s...' % os.path.join(self.root,self.img_dir) + return res + + def viz_distr(self): + from matplotlib import pyplot as pl; pl.ion() + count = [0]*self.nclass + for i in range(self.nimg): + count[ self.get_label(i,toint=True) ] += 1 + cid = list(range(self.nclass)) + pl.bar(cid, count) + pdb.set_trace() + + +class CatDataset(Dataset): + ''' Concatenation of several datasets. 
+ ''' + def __init__(self, *datasets): + assert len(datasets) >= 1 + self.datasets = datasets + + db = datasets[0] + self.root = os.path.normpath(os.path.join(db.root, db.img_dir)) + os.sep + self.labels = self.imgs = None # cannot access it the normal way + self.classes = db.classes + self.nclass = db.nclass + self.c_relevant_idx = defaultdict(list) + + offsets = [0] + full_root = lambda db: os.path.normpath(os.path.join(db.root, db.img_dir)) + for db in datasets: + assert db.nclass == self.nclass, 'All dataset must have the same number of classes' + assert db.classes == self.classes, 'All datasets must have the same classes' + + # look for a common root + self.root = os.path.commonprefix((self.root, full_root(db) + os.sep)) + assert self.root, 'no common root between datasets' + self.root = self.root[:self.root.rfind(os.sep)] + os.sep + + offset = sum(offsets) + for label, rel in db.c_relevant_idx.items(): + self.c_relevant_idx[label] += [i+offset for i in rel] + offsets.append(db.nimg) + + self.roots = [full_root(db)[len(self.root):] for db in datasets] + self.offsets = np.cumsum(offsets) + self.nimg = self.offsets[-1] + + def which(self, i): + pos = np.searchsorted(self.offsets, i, side='right')-1 + assert pos < self.nimg, 'Bad image index %d >= %d' % (i, self.nimg) + return pos, i - self.offsets[pos] + + def get(self, i, attr): + b, j = self.which(i) + return getattr(self.datasets[b],attr) + + def __getattr__(self, name): + # try getting it + val = getattr(self.datasets[0], name) + assert not callable(val), 'CatDataset: %s is not a shared attribute, use call()' % str(name) + for db in self.datasets[1:]: + assert np.all(val == getattr(db, name)), 'CatDataset: inconsistent shared attribute %s, use get()' % str(name) + return val + + def call(self, i, func, *args, **kwargs): + b, j = self.which(i) + return getattr(self.datasets[b],attr)(j,*args, **kwargs) + + def get_key(self, i): + b, i = self.which(i) + key = self.datasets[b].get_key(i) + return os.path.join(self.roots[b], key) + + def get_label(self, i, toint=False): + b, i = self.which(i) + return self.datasets[b].get_label(i,toint=toint) + + def get_bbox(self,i): + b, i = self.which(i) + return self.datasets[b].get_bbox(i) + + def get_polygons(self,i,**kw): + b, i = self.which(i) + return self.datasets[b].get_polygons(i,**kw) + + + + +def deploy( dataset, target_dir, transforms=None, redo=False, ext=None, **savekwargs): + if not target_dir: return dataset + from PIL import Image + from fcntl import flock, LOCK_EX + import tqdm + + if transforms is not None: + # identify transform with a unique hash + import hashlib + def get_params(trf): + if type(trf).__name__ == 'Compose': + return [get_params(t) for t in trf.transforms] + else: + return {type(trf).__name__:vars(trf)} + params = get_params(transforms) + unique_key = json.dumps(params, sort_keys=True).encode('utf-8') + h = hashlib.md5().hexdigest() + target_dir = os.path.join(target_dir, h) + print("Deploying in '%s'" % target_dir) + + try: + imsizes_path = os.path.join(target_dir,'imsizes.json') + imsize_file = open(imsizes_path,'r+') + #print("opening %s in r+ mode"%imsize_file) + except IOError: + try: os.makedirs(os.path.split(imsizes_path)[0]) + except OSError: pass + imsize_file = open(imsizes_path,'w+') + #print("opening %s in w+ mode"%imsize_file) + + # block access to this file, only one process can continue + from time import time as now + t0 = now() + flock(imsize_file, LOCK_EX) + #print("exclusive access lock for %s after %ds"%(imsize_file,now()-t0)) + + try: + imsizes 
= json.load(imsize_file) + imsizes = {img:tuple(size) for img,size in imsizes.items()} + except: + imsizes = {} + + def check_one_image(i): + key = dataset.get_key(i) + target = os.path.join(target_dir, key) + if ext: target = os.path.splitext(target)[0]+'.'+ext + + updated = 0 + if redo or (not os.path.isfile(target)) or key not in imsizes: + # load image and transform it + img = Image.open(dataset.get_filename(i)).convert('RGB') + imsizes[key] = img.size + + if transforms is not None: + img = transforms(img) + + odir = os.path.split( target )[0] + try: os.makedirs(odir) + except FileExistsError: pass + img.save( target, **savekwargs ) + + updated = 1 + if (i % 100) == 0: + imsize_file.seek(0) # goto begining + json.dump(dict(imsizes), imsize_file) + imsize_file.truncate() + updated = 0 + + return updated + + from nltools.gutils import job_utils + for i in range(len(dataset)): + updated = check_one_image(i) # first try without any threads + if updated: break + if i+1 < len(dataset): + updated += sum(job_utils.parallel_threads(range(i+1,len(dataset)), check_one_image, + desc='Deploying dataset', n_threads=0, front_num=0)) + + if updated: + imsize_file.seek(0) # goto begining + json.dump(dict(imsizes), imsize_file) + imsize_file.truncate() + imsize_file.close() # now, other processes can access too + + return DeployedDataset(dataset, target_dir, imsizes, trfs=transforms, ext=ext) + + + +class DeployedDataset(Dataset): + '''Just a deployed dataset with a different root and image extension. + ''' + def __init__(self, dataset, root, imsizes=None, trfs=None, ext=None): + self.dataset = dataset + if root[-1] != '/': root += '/' + self.root = root + self.ext = ext + self.imsizes = imsizes or json.load(open(root+'imsizes.json')) + self.trfs = trfs or (lambda x: x) + assert isinstance(self.imsizes, dict) + assert len(self.imsizes) >= dataset.nimg, pdb.set_trace() + + self.nimg = dataset.nimg + self.nclass = dataset.nclass + + self.labels = dataset.labels + self.c_relevant_idx = dataset.c_relevant_idx + #self.c_non_relevant_idx = dataset.c_non_relevant_idx + + self.get_label = dataset.get_label + self.classes = dataset.classes + if '/query_db/' not in root: + try: + query_db = dataset.get_query_db() + if query_db is not dataset: + self.query_db = deploy(query_db, os.path.join(root,'query_db'), transforms=trfs, ext=ext) + self.get_query_db = lambda: self.query_db + except NotImplementedError: + pass + self.get_query_groundtruth = dataset.get_query_groundtruth + if hasattr(dataset, 'eval_query_AP'): + self.eval_query_AP = dataset.eval_query_AP + + if hasattr(dataset, 'true_pairs'): + self.true_pairs = dataset.true_pairs + self.get_false_pairs = dataset.get_false_pairs + + def __repr__(self): + res = self.dataset.__repr__() + res += ' deployed at %s/...%s' % (self.root, self.ext or '') + return res + + def __len__(self): + return self.nimg + + def get_key(self, i): + key = self.dataset.get_key(i) + if self.ext: key = os.path.splitext(key)[0]+'.'+self.ext + return key + + def get_something(self, what, i, *args, **fmt): + try: + get_func = getattr(self.dataset, 'get_'+what) + except AttributeError: + raise NotImplementedError() + imsize = self.imsizes[self.dataset.get_key(i)] + sth = get_func(i,*args,**fmt) + return self.trfs({'imsize':imsize, what:sth})[what] + + def get_bbox(self, i, **kw): + return self.get_something('bbox', i, **kw) + + def get_polygons(self, i, *args, **kw): + return self.get_something('polygons', i, *args, **kw) + + def get_label_map(self, i, *args, **kw): + assert 'polygons' in 
kw, "you need to supply polygons because image has been transformed" + return self.dataset.get_label_map(i, *args, **kw) + def get_instance_map(self, i, *args, **kw): + assert 'polygons' in kw, "you need to supply polygons because image has been transformed" + return self.dataset.get_instance_map(i, *args, **kw) + def get_angle_map(self, i, *args, **kw): + assert 'polygons' in kw, "you need to supply polygons because image has been transformed" + return self.dataset.get_angle_map(i, *args, **kw) + + def original(self): + return self.dataset + + + +def deploy_and_split( trainset, deploy_trf=None, deploy_dir='/dev/shm', + valset=None, split_val=0.0, + img_ext='jpg', img_quality=95, + **_useless ): + ''' Deploy and split a dataset into train / val. + if valset is not provided, then trainset is automatically split into train/val + based on the split_val proportion. + ''' + # first, deploy the training set + traindb = deploy( trainset, deploy_dir, transforms=deploy_trf, ext=img_ext, quality=img_quality ) + + if valset: + # load a validation db + valdb = deploy( valset, deploy_dir, transforms=deploy_trf, ext=img_ext, quality=img_quality ) + + else: + if split_val > 0: + # automatic split in train/val + traindb, valdb = split( traindb, train_prop=1-split_val ) + else: + valdb = None + + print( "\n>> Training set:" ); print( traindb ) + print( "\n>> Validation set:" ); print( valdb ) + return traindb, valdb + + + + +class CropDataset(Dataset): + """list_of_imgs_and_crops = [(img_key, (l, t, r, b)), ...] + """ + def __init__(self, dataset, list_of_imgs_and_crops): + self.dataset = dataset + self.root = dataset.root + self.img_dir = dataset.img_dir + self.imgs, self.crops = zip(*list_of_imgs_and_crops) + self.nimg = len(self.imgs) + + def get_image(self, img_idx): + # even if the image have multiple signage polygon? 
+ org_img = dataset.get_image(self, img_idx) + crop_signs = crop_image(org_img, self.crops[img_idx]) + + return crop_signs[0] # temporary use one, but have to change for multiple signages + + def get_filename(self, img_idx): + return self.dataset.get_filename(img_idx) + + def get_key(self, img_idx): + return self.dataset.get_key(img_idx) + + def crop_image(self, img, polygons): + import cv2 + crop_signs=[] + if len(polygons)==0: + pdb.set_trace() + + for Polycc in polygons: + rgbimg = img.copy() + rgbimg = np.array(rgbimg) # pil to cv2 + Poly_s = np.array(Polycc) + + ## rearrange + if Poly_s[0, 1]Poly_s[3, 1]: + temp = Poly_s[3, :].copy() + Poly_s[3, :]= Poly_s[2, :] + Poly_s[2, :]=temp + + cy_s = np.mean( Poly_s[:,0] ) + cx_s = np.mean( Poly_s[:,1] ) + w_s = np.abs( Poly_s[0][1]-Poly_s[1][1] ) + h_s = np.abs( Poly_s[0][0]-Poly_s[2][0] ) + Poly_d = np.array([(cy_s-h_s/2, cx_s+w_s/2), (cy_s-h_s/2, cx_s-w_s/2), (cy_s+h_s/2, cx_s-w_s/2), (cy_s+h_s/2, cx_s+w_s/2)]).astype(np.int) + + M, mask= cv2.findHomography(Poly_s, Poly_d) + + warpimg = Image.fromarray(cv2.warpPerspective(rgbimg, M, (645,800))) # from cv2 type rgbimg + crop_sign = warpimg.crop([np.min(Poly_d[:,0]), np.min(Poly_d[:,1]), np.max(Poly_d[:,0]), np.max(Poly_d[:,1])]) + + ### append + crop_signs.append(crop_sign) + + return crop_signs + + + + + + + + + + + + + + diff --git a/dirtorch/datasets/downloader.py b/dirtorch/datasets/downloader.py new file mode 100644 index 0000000..5cdaff1 --- /dev/null +++ b/dirtorch/datasets/downloader.py @@ -0,0 +1,52 @@ +import os +import os.path as osp + +DB_ROOT = os.environ['DB_ROOT'] + +def download_dataset(dataset): + if not os.path.isdir(DB_ROOT): + os.makedirs(DB_ROOT) + + dataset = dataset.lower() + if dataset in ('oxford5k', 'roxford5k'): + src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' + dl_files = ['oxbuild_images.tgz'] + dir_name = 'oxford5k' + elif dataset in ('paris6k', 'rparis6k'): + src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' + dl_files = ['paris_1.tgz', 'paris_2.tgz'] + dir_name = 'paris6k' + else: + raise ValueError('Unknown dataset: {}!'.format(dataset)) + + dst_dir = os.path.join(DB_ROOT, dir_name, 'jpg') + if not os.path.isdir(dst_dir): + print('>> Dataset {} directory does not exist. 
Creating: {}'.format(dataset, dst_dir)) + os.makedirs(dst_dir) + for dli in range(len(dl_files)): + dl_file = dl_files[dli] + src_file = os.path.join(src_dir, dl_file) + dst_file = os.path.join(dst_dir, dl_file) + print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file)) + os.system('wget {} -O {}'.format(src_file, dst_file)) + print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file)) + # create tmp folder + dst_dir_tmp = os.path.join(dst_dir, 'tmp') + os.system('mkdir {}'.format(dst_dir_tmp)) + # extract in tmp folder + os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp)) + # remove all (possible) subfolders by moving only files in dst_dir + os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir)) + # remove tmp folder + os.system('rm -rf {}'.format(dst_dir_tmp)) + print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file)) + os.system('rm {}'.format(dst_file)) + + gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset) + gnd_dst_dir = os.path.join(DB_ROOT, dir_name) + gnd_dl_file = 'gnd_{}.pkl'.format(dataset) + gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file) + gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file) + if not os.path.exists(gnd_dst_file): + print('>> Downloading dataset {} ground truth file...'.format(dataset)) + os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file)) diff --git a/dirtorch/datasets/generic.py b/dirtorch/datasets/generic.py new file mode 100644 index 0000000..c0915ad --- /dev/null +++ b/dirtorch/datasets/generic.py @@ -0,0 +1,281 @@ +import os +import json +import pdb +import numpy as np +import pickle +import os.path as osp +import json + +from .dataset import Dataset +from .generic_func import * + + +class ImageList(Dataset): + ''' Just a list of images (no labels, no query). + + Input: text file, 1 image path per row + ''' + def __init__(self, img_list_path, root='', imgs=None): + self.root = root + if imgs is not None: + self.imgs = imgs + else: + self.imgs = [e.strip() for e in open(img_list_path)] + + self.nimg = len(self.imgs) + self.nclass = 0 + self.nquery = 0 + + def get_key(self, i): + return self.imgs[i] + + +class LabelledDataset (Dataset): + """ A dataset with per-image labels + and some convenient functions. + """ + def find_classes(self, *arg, **cls_idx): + labels = arg[0] if arg else self.labels + self.classes, self.cls_idx = find_and_list_classes(labels, cls_idx=cls_idx) + self.nclass = len(self.classes) + self.c_relevant_idx = find_relevants(self.labels) + + +class ImageListLabels(LabelledDataset): + ''' Just a list of images with labels (no queries). 
+ + Input: text file, 1 image path and label per row (space-separated) + ''' + def __init__(self, img_list_path, root=None): + self.root = root + if osp.splitext(img_list_path)[1] == '.txt': + tmp = [e.strip() for e in open(img_list_path)] + self.imgs = [e.split(' ')[0] for e in tmp] + self.labels = [e.split(' ')[1] for e in tmp] + elif osp.splitext(img_list_path)[1] == '.json': + d = json.load(open(img_list_path)) + self.imgs = [] + self.labels = [] + for i, l in d.items(): + self.imgs.append(i) + self.labels.append(l) + self.find_classes() + + self.nimg = len(self.imgs) + self.nquery = 0 + + def get_key(self, i): + return self.imgs[i] + + def get_label(self, i, toint=False): + label = self.labels[i] + if toint: label = self.cls_idx[ label ] + return label + + def get_query_db(self): + return self + +class ImageListLabelsQ(ImageListLabels): + ''' Two list of images with labels: one for the dataset and one for the queries. + + Input: text file, 1 image path and label per row (space-separated) + ''' + def __init__(self, img_list_path, query_list_path, root=None): + self.root = root + tmp = [e.strip() for e in open(img_list_path)] + self.imgs = [e.split(' ')[0] for e in tmp] + self.labels = [e.split(' ')[1] for e in tmp] + tmp = [e.strip() for e in open(query_list_path)] + self.qimgs = [e.split(' ')[0] for e in tmp] + self.qlabels = [e.split(' ')[1] for e in tmp] + self.find_classes() + + self.nimg = len(self.imgs) + self.nquery = len(self.qimgs) + + def find_classes(self, *arg, **cls_idx): + labels = arg[0] if arg else self.labels + self.qlabels + self.classes, self.cls_idx = find_and_list_classes(labels, cls_idx=cls_idx) + self.nclass = len(self.classes) + self.c_relevant_idx = find_relevants(self.labels) + + def get_query_db(self): + return ImagesAndLabels(self.qimgs, self.qlabels, self.cls_idx, root=self.root) + + +class ImagesAndLabels(ImageListLabels): + ''' Just a list of images with labels. + + Input: two arrays containing the text file, 1 image path and label per row (space-separated) + ''' + def __init__(self, imgs, labels, cls_idx, root=None): + self.root = root + self.imgs = imgs + self.labels = labels + self.cls_idx = cls_idx + self.nclass = len(self.cls_idx.keys()) + + self.nimg = len(self.imgs) + self.nquery = 0 + + +class ImageListRelevants(Dataset): + """ A dataset composed by a list of images, a list of indices used as queries, + and for each query a list of relevant and junk indices (ie. 
Oxford-like GT format) + + Input: path to the pickle file + """ + def __init__(self, gt_file, root=None, img_dir = 'jpg', ext='.jpg'): + self.root = root + self.img_dir = img_dir + + with open(gt_file, 'rb') as f: + gt = pickle.load(f) + self.imgs = [osp.splitext(e)[0] + (osp.splitext(e)[1] if osp.splitext(e)[1] else ext) for e in gt['imlist']] + self.qimgs = [osp.splitext(e)[0] + (osp.splitext(e)[1] if osp.splitext(e)[1] else ext) for e in gt['qimlist']] + self.qroi = [tuple(e['bbx']) for e in gt['gnd']] + if 'ok' in gt['gnd'][0]: + self.relevants = [e['ok'] for e in gt['gnd']] + else: + self.relevants = None + self.easy = [e['easy'] for e in gt['gnd']] + self.hard = [e['hard'] for e in gt['gnd']] + self.junk = [e['junk'] for e in gt['gnd']] + + self.nimg = len(self.imgs) + self.nquery = len(self.qimgs) + + def get_relevants(self, qimg_idx, mode='classic'): + if mode=='classic': rel = self.relevants[qimg_idx] + elif mode=='easy': rel = self.easy[qimg_idx] + elif mode=='medium': rel = self.easy[qimg_idx] + self.hard[qimg_idx] + elif mode=='hard': rel = self.hard[qimg_idx] + return rel + + def get_junk(self, qimg_idx, mode='classic'): + if mode=='classic': junk = self.junk[qimg_idx] + elif mode=='easy': junk = self.junk[qimg_idx] + self.hard[qimg_idx] + elif mode=='medium': junk = self.junk[qimg_idx] + elif mode=='hard': junk = self.junk[qimg_idx] + self.easy[qimg_idx] + return junk + + def get_query_filename(self, qimg_idx, root=None): + return os.path.join(root or self.root, self.img_dir, self.get_query_key(qimg_idx)) + + def get_query_roi(self, qimg_idx): + return self.qroi[qimg_idx] + + def get_key(self, i): + return self.imgs[i] + + def get_query_key(self, i): + return self.qimgs[i] + + def get_query_db(self): + return ImageListROIs(self.root, self.img_dir, self.qimgs, self.qroi) + + def get_query_groundtruth(self, query_idx, what='AP', mode='classic'): + res = -np.ones(self.nimg, dtype=np.int8) # all negatives + res[self.get_relevants(query_idx, mode)] = 1 # positives + res[self.get_junk(query_idx, mode)] = 0 # junk + return res + + def eval_query_AP(self, query_idx, scores): + """ Evaluates AP for a given query. + """ + from ..utils.evaluation import compute_AP + if self.relevants: + gt = self.get_query_groundtruth(query_idx, 'AP') # labels in {-1, 0, 1} + if gt.shape != scores.shape: + # TODO: Get this number in a less hacky way. This was the number of non-corrupted distractors + gt = np.concatenate([gt, np.full((976089,), fill_value=-1)]) + assert gt.shape == scores.shape, "scores should have shape %s" % str(gt.shape) + assert -1 <= gt.min() and gt.max() <= 1, "bad ground-truth labels" + keep = (gt != 0) # remove null labels + return compute_AP(gt[keep]>0, scores[keep]) + else: + d = {} + for mode in ('easy', 'medium', 'hard'): + gt = self.get_query_groundtruth(query_idx, 'AP', mode) # labels in {-1, 0, 1} + if gt.shape != scores.shape: + # TODO: Get this number in a less hacky way. 
This was the number of non-corrupted distractors + gt = np.concatenate([gt, np.full((976089,), fill_value=-1)]) + assert gt.shape == scores.shape, "scores should have shape %s" % str(gt.shape) + assert -1 <= gt.min() and gt.max() <= 1, "bad ground-truth labels" + keep = (gt != 0) # remove null labels + if sum(gt[keep]>0) == 0: #exclude queries with no relevants from the evaluation + d[mode] = -1 + else: + d[mode] = compute_AP(gt[keep]>0, scores[keep]) + return d + + +class ImageListROIs(Dataset): + def __init__(self, root, img_dir, imgs, rois): + self.root = root + self.img_dir = img_dir + self.imgs = imgs + self.rois = rois + + self.nimg = len(self.imgs) + self.nclass = 0 + self.nquery = 0 + + def get_key(self, i): + return self.imgs[i] + + def get_roi(self, i): + return self.rois[i] + + def get_image(self, img_idx, resize=None): + from PIL import Image + img = Image.open(self.get_filename(img_idx)).convert('RGB') + img = img.crop(self.rois[img_idx]) + if resize: + img = img.resize(resize, Image.ANTIALIAS if np.prod(resize) < np.prod(img.size) else Image.BICUBIC) + return img + +def not_none(label): + return label is not None + + +class ImageClusters(LabelledDataset): + ''' Just a list of images with labels (no query). + + Input: JSON, dict of {img_path:class, ...} + ''' + def __init__(self, json_path, root=None, filter=not_none): + self.root = root + self.imgs = [] + self.labels = [] + if isinstance(json_path, dict): + data = json_path + else: + data = json.load(open(json_path)) + assert isinstance(data, dict), 'json content is not a dictionary' + + for img, cls in data.items(): + assert type(img) is str + if not filter(cls): continue + if type(cls) not in (str,int,type(None)): continue + self.imgs.append( img ) + self.labels.append( cls ) + + self.find_classes() + self.nimg = len(self.imgs) + self.nquery = 0 + + def get_key(self, i): + return self.imgs[i] + + def get_label(self, i, toint=False): + label = self.labels[i] + if toint: label = self.cls_idx[ label ] + return label + + +class NullCluster(ImageClusters): + ''' Select only images with null label + ''' + def __init__(self, json_path, root=None): + ImageClusters.__init__(self, json_path, root, lambda c: c is None) diff --git a/dirtorch/datasets/generic_func.py b/dirtorch/datasets/generic_func.py new file mode 100644 index 0000000..c8fa190 --- /dev/null +++ b/dirtorch/datasets/generic_func.py @@ -0,0 +1,62 @@ +''' Generic functions for Dataset() class +''' +import pdb +import numpy as np +from collections import defaultdict + + +def find_and_list_classes(labels, cls_idx=None ): + ''' Given a list of image labels, deduce the list of classes. + + Parameters: + ----------- + labels : list + per-image labels (can be str, int, ...) + + cls_idx : dict or None + + Returns: + -------- + classes = [class0_name, class1_name, ...] + cls_idx = {class_name : class_index} + ''' + assert not isinstance(labels, set), 'labels must be ordered' + if not cls_idx: cls_idx = {} # don't put it as default arg!! 
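+    # cls_idx can pre-assign indices to some classes (e.g. cls_idx={'background': 0});
+    # the remaining classes get the leftover integer ids, in no guaranteed order.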
+ + # map string labels to integers + uniq_labels = set(labels) + nlabels = len(uniq_labels) + for label in cls_idx: + assert label in uniq_labels, "error: missing forced label '%s'" % str(label) + nlabels += (label not in uniq_labels) # one other label + + classes = {idx:cls for cls,idx in cls_idx.items()} + remaining_labels = set(range(nlabels)) - set(cls_idx.values()) + for cls in labels: + if cls in cls_idx: continue # already there + cls_idx[cls] = i = remaining_labels.pop() + classes[cls_idx[cls]] = cls + + assert min(classes.keys()) == 0 and len(classes) == max(classes.keys()) + 1 # no holes between integers + classes = [classes[c] for c in range(len(classes))] # dict --> list + + return classes, cls_idx + + +def find_relevants(labels): + """ For each class, find the set of images from the same class. + + Returns: + -------- + c_relevant_idx = {class_name: [list of image indices]} + """ + assert not isinstance(labels, set), 'labels must be ordered' + + # Get relevants images for each class + c_relevant_idx = defaultdict(list) + for i in range(len(labels)): + c_relevant_idx[labels[i]].append(i) + + return c_relevant_idx + + diff --git a/dirtorch/datasets/landmarks.py b/dirtorch/datasets/landmarks.py new file mode 100644 index 0000000..2c0bbeb --- /dev/null +++ b/dirtorch/datasets/landmarks.py @@ -0,0 +1,19 @@ +import os +from .generic import ImageListLabels + +DB_ROOT = os.environ['DB_ROOT'] + +class Landmarks_clean(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks/annotations/annotation_clean_train.txt'), + os.path.join(DB_ROOT, 'landmarks/')) + +class Landmarks_clean_val(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks/annotations/annotation_clean_val.txt'), + os.path.join(DB_ROOT, 'landmarks/')) + +class Landmarks_lite(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks/annotations/extra_landmark_images.txt'), + os.path.join(DB_ROOT, 'landmarks/')) diff --git a/dirtorch/datasets/landmarks18.py b/dirtorch/datasets/landmarks18.py new file mode 100644 index 0000000..5007f11 --- /dev/null +++ b/dirtorch/datasets/landmarks18.py @@ -0,0 +1,66 @@ +import os +from .generic import ImageListLabels, ImageList + +DB_ROOT = os.environ['DB_ROOT'] + +class Landmarks18_train(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/train.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/train_all.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_lite(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/train_lite.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_mid(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/train_mid.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_5K(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/train_5K.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_val(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/val.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class 
Landmarks18_valdstr(ImageListLabels): + def __init__(self): + ImageListLabels.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/val_distractors.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_index(ImageList): + def __init__(self): + ImageList.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/index.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_new_index(ImageList): + def __init__(self): + ImageList.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/index_new.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_test(ImageList): + def __init__(self): + ImageList.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/test.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_pca(ImageList): + def __init__(self): + ImageList.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/train_pca.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + +class Landmarks18_missing_index(ImageList): + def __init__(self): + ImageList.__init__(self, os.path.join(DB_ROOT, 'landmarks18/lists/missing_index.txt'), + os.path.join(DB_ROOT, 'landmarks18/')) + + diff --git a/dirtorch/datasets/oxford.py b/dirtorch/datasets/oxford.py new file mode 100644 index 0000000..e8168cc --- /dev/null +++ b/dirtorch/datasets/oxford.py @@ -0,0 +1,14 @@ +import os +from .generic import ImageListRelevants + +DB_ROOT = os.environ['DB_ROOT'] + +class Oxford5K(ImageListRelevants): + def __init__(self): + ImageListRelevants.__init__(self, os.path.join(DB_ROOT, 'oxford5k/gnd_oxford5k.pkl'), + root=os.path.join(DB_ROOT, 'oxford5k')) + +class ROxford5K(ImageListRelevants): + def __init__(self): + ImageListRelevants.__init__(self, os.path.join(DB_ROOT, 'oxford5k/gnd_roxford5k.pkl'), + root=os.path.join(DB_ROOT, 'oxford5k')) diff --git a/dirtorch/datasets/paris.py b/dirtorch/datasets/paris.py new file mode 100644 index 0000000..02d7ec9 --- /dev/null +++ b/dirtorch/datasets/paris.py @@ -0,0 +1,14 @@ +from .generic import ImageListRelevants +import os + +DB_ROOT = os.environ['DB_ROOT'] + +class Paris6K(ImageListRelevants): + def __init__(self): + ImageListRelevants.__init__(self, os.path.join(DB_ROOT, 'paris6k/gnd_paris6k.pkl'), + root=os.path.join(DB_ROOT, 'paris6k')) + +class RParis6K(ImageListRelevants): + def __init__(self): + ImageListRelevants.__init__(self, os.path.join(DB_ROOT, 'paris6k/gnd_rparis6k.pkl'), + root=os.path.join(DB_ROOT, 'paris6k')) diff --git a/dirtorch/extract_features.py b/dirtorch/extract_features.py new file mode 100644 index 0000000..521ef0a --- /dev/null +++ b/dirtorch/extract_features.py @@ -0,0 +1,188 @@ +import sys +import os; os.umask(7) # group permisions but that's all +import os.path as osp +import pdb + +import json +import tqdm +import numpy as np +import torch +import torch.nn.functional as F + +from dirtorch.utils.convenient import mkdir +from dirtorch.utils import common +from dirtorch.utils.pytorch_loader import get_loader + +import dirtorch.test_dir as test +import dirtorch.nets as nets +import dirtorch.datasets as datasets + +import pickle as pkl +import hashlib + +def hash(x): + m = hashlib.md5() + m.update(str(x).encode('utf-8')) + return m.hexdigest() + +def typename(x): + return type(x).__module__ + +def tonumpy(x): + if typename(x) == torch.__name__: + return x.cpu().numpy() + else: + return x + + +def pool(x, pooling='mean', gemp=3): + if len(x) == 1: return x[0] + x = torch.stack(x, dim=0) + if pooling == 'mean': + return torch.mean(x, dim=0) + elif pooling == 'gem': + def sympow(x, p, eps=1e-6): + s = 
torch.sign(x) + return (x*s).clamp(min=eps).pow(p) * s + x = sympow(x,gemp) + x = torch.mean(x, dim=0) + return sympow(x, 1/gemp) + else: + raise ValueError("Bad pooling mode: "+str(pooling)) + + +def extract_features(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=None, + threads=8, batch_size=16, output=None, dbg=()): + """ Extract features from trained model (network) on a given dataset. + """ + print("\n>> Extracting features...") + try: + query_db = db.get_query_db() + except NotImplementedError: + query_db = None + + # extract DB feats + bdescs = [] + qdescs = [] + + trfs_list = [trfs] if isinstance(trfs, str) else trfs + + for trfs in trfs_list: + kw = dict(iscuda=net.iscuda, threads=threads, batch_size=batch_size, same_size='Pad' in trfs or 'Crop' in trfs) + bdescs.append( test.extract_image_features(db, trfs, net, desc="DB", **kw) ) + + # extract query feats + if query_db is not None: + qdescs.append( bdescs[-1] if db is query_db else test.extract_image_features(query_db, trfs, net, desc="query", **kw) ) + + # pool from multiple transforms (scales) + bdescs = tonumpy(F.normalize(pool(bdescs, pooling, gemp), p=2, dim=1)) + if query_db is not None: + qdescs = tonumpy(F.normalize(pool(qdescs, pooling, gemp), p=2, dim=1)) + + if whiten is not None: + bdescs = common.whiten_features(bdescs, net.pca, **whiten) + if query_db is not None: + qdescs = common.whiten_features(qdescs, net.pca, **whiten) + + mkdir(output, isfile=True) + if query_db is db or query_db is None: + np.save(output, bdescs) + else: + o = osp.splitext(output) + np.save(o[0]+'.qdescs'+o[1], qdescs) + np.save(o[0]+'.dbdescs'+o[1], bdescs) + print('Features extracted.') + + +def load_model( path, iscuda, whiten=None ): + checkpoint = common.load_checkpoint(path, iscuda) + net = nets.create_model(pretrained="", **checkpoint['model_options']) + net = common.switch_model_to_cuda(net, iscuda, checkpoint) + net.load_state_dict(checkpoint['state_dict']) + net.preprocess = checkpoint.get('preprocess', net.preprocess) + if whiten is not None and 'pca' in checkpoint: + if whiten in checkpoint['pca']: + net.pca = checkpoint['pca'][whiten] + return net + + +def learn_whiten( dataset, net, trfs='', pooling='mean', threads=8, batch_size=16): + descs = [] + trfs_list = [trfs] if isinstance(trfs, str) else trfs + for trfs in trfs_list: + kw = dict(iscuda=net.iscuda, threads=threads, batch_size=batch_size, same_size='Pad' in trfs or 'Crop' in trfs) + descs.append( extract_image_features(dataset, trfs, net, desc="PCA", **kw) ) + # pool from multiple transforms (scales) + descs = F.normalize(pool(descs, pooling), p=2, dim=1) + # learn pca with whiten + pca = common.learn_pca(descs.cpu().numpy(), whiten=True) + return pca + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Evaluate a model') + + parser.add_argument('--dataset', '-d', type=str, required=True, help='Command to load dataset') + parser.add_argument('--checkpoint', type=str, required=True, help='path to weights') + + parser.add_argument('--trfs', type=str, required=False, default='', nargs='+', help='test transforms (can be several)') + parser.add_argument('--pooling', type=str, default="gem", help='pooling scheme if several trf chains') + parser.add_argument('--gemp', type=int, default=3, help='GeM pooling power') + parser.add_argument('--center-bias', type=float, default=0, help='enforce some center bias') + + parser.add_argument('--out-json', type=str, default="", help='path to output json') + 
parser.add_argument('--detailed', action='store_true', help='return detailed evaluation') + parser.add_argument('--output', type=str, default="", help='path to output features') + + parser.add_argument('--threads', type=int, default=8, help='number of thread workers') + parser.add_argument('--gpu', type=int, nargs='+', help='GPU ids') + parser.add_argument('--dbg', default=(), nargs='*', help='debugging options') + # post-processing + parser.add_argument('--whiten', type=str, default=None, help='applies whitening') + + parser.add_argument('--whitenp', type=float, default=0.5, help='whitening power, default is 0.5 (i.e., the sqrt)') + parser.add_argument('--whitenv', type=int, default=None, help='number of components, default is None (i.e. all components)') + parser.add_argument('--whitenm', type=float, default=1.0, help='whitening multiplier, default is 1.0 (i.e. no multiplication)') + + args = parser.parse_args() + args.iscuda = common.torch_set_gpu(args.gpu) + + dataset = datasets.create(args.dataset) + print("Dataset:", dataset) + + net = load_model(args.checkpoint, args.iscuda, args.whiten) + + if args.center_bias: + assert hasattr(net,'center_bias') + net.center_bias = args.center_bias + if hasattr(net, 'module') and hasattr(net.module,'center_bias'): + net.module.center_bias = args.center_bias + + if args.whiten and not hasattr(net, 'pca'): + # Learn PCA if necessary + if os.path.exists(args.whiten): + with open(args.whiten, 'rb') as f: + net.pca = pkl.load(f) + else: + pca_path = '_'.join([args.checkpoint, args.whiten, args.pooling, hash(args.trfs), 'pca.pkl']) + db = datasets.create(args.whiten) + print('Dataset for learning the PCA with whitening:', db) + pca = learn_whiten(db, net, pooling=args.pooling, trfs=args.trfs, threads=args.threads) + + chk = torch.load(args.checkpoint, map_location=lambda storage, loc: storage) + if 'pca' not in chk: chk['pca'] = {} + chk['pca'][args.whiten] = pca + torch.save(chk, args.checkpoint) + + net.pca = pca + + if args.whiten: + args.whiten = {'whitenp': args.whitenp, 'whitenv': args.whitenv, 'whitenm': args.whitenm} + + # Evaluate + res = extract_features(dataset, net, args.trfs, pooling=args.pooling, gemp=args.gemp, detailed=args.detailed, + threads=args.threads, dbg=args.dbg, whiten=args.whiten, output=args.output) + + diff --git a/dirtorch/nets/__init__.py b/dirtorch/nets/__init__.py new file mode 100644 index 0000000..b49ee59 --- /dev/null +++ b/dirtorch/nets/__init__.py @@ -0,0 +1,127 @@ +''' List all architectures at the bottom of this file. + +To list all available architectures, use: + python -m nets +''' +import os +import pdb +import torch +from collections import OrderedDict + + +def list_archs(): + model_names = {name for name in globals() + if name.islower() and not name.startswith("__") + and name not in internal_funcs + and callable(globals()[name])} + return model_names + + +def create_model(arch, pretrained='', delete_fc=False, *args, **kwargs): + ''' Create an empty network for RMAC. 
+ + arch : str + name of the function to call + + kargs : list + mandatory arguments + + kwargs : dict + optional arguments + ''' + # creating model + if arch not in globals(): + raise NameError("unknown model architecture '%s'\nSelect one in %s" % ( + arch, ','.join(list_archs()))) + model = globals()[arch](*args, **kwargs) + + model.preprocess = dict( + mean = model.rgb_means, + std = model.rgb_stds, + input_size = max(model.input_size) ) + + if os.path.isfile(pretrained or ''): + class watcher: + class AverageMeter: pass + class Watch: pass + import sys + sys.modules['utils.watcher'] = watcher + weights = torch.load(pretrained, map_location=lambda storage, loc: storage)['state_dict'] + load_pretrained_weights(model, weights, delete_fc=delete_fc) + + elif pretrained: + assert hasattr(model, 'load_pretrained_weights'), 'Model %s must be initialized with a valid model file (not %s)' % (arch, pretrained) + model.load_pretrained_weights(pretrained) + + return model + + +def load_pretrained_weights(net, state_dict, delete_fc=False): + """ Load the pretrained weights (chop the last FC layer if needed) + If layers are missing or of wrong shape, will not load them. + """ + + new_dict = OrderedDict() + for k,v in list(state_dict.items()): + if k.startswith('module.'): k = k.replace('module.', '') + new_dict[k]=v + + # Add missing weights from the network itself + d = net.state_dict() + for k,v in list(d.items()): + if k not in new_dict: + if not k.endswith('num_batches_tracked'): + print("Loading weights for %s: Missing layer %s" % (type(net).__name__,k)) + new_dict[k] = v + elif v.shape != new_dict[k].shape: + print("Loading weights for %s: Bad shape for layer %s, skipping" % (type(net).__name__,k)) + new_dict[k] = v + + net.load_state_dict(new_dict) + + # Remove the FC layer if size doesn't match + if delete_fc: + fc = net.fc_name + del new_dict[fc+'.weight'] + del new_dict[fc+'.bias'] + + +""" Import every network HERE +""" +internal_funcs = set(globals().keys()) +from .backbones.resnet import resnet101, resnet50, resnet18, resnet152 +from .rmac_resnet import resnet18_rmac, resnet50_rmac, resnet101_rmac, resnet152_rmac +from .rmac_senet import senet154_rmac, se_resnet50_rmac, se_resnet101_rmac, se_resnet152_rmac, se_resnext50_32x4d_rmac, se_resnext101_32x4d_rmac +from .rmac_resnet_ms import resnet18_rmac_ms, resnet50_rmac_ms, resnet101_rmac_ms, resnet152_rmac_ms +from .rmac_inceptionresnetv2 import inceptionresnetv2_rmac +from .rmac_resnet_fpn import resnet18_fpn_rmac, resnet50_fpn_rmac, resnet101_fpn_rmac, resnet101_fpn0_rmac, resnet152_fpn_rmac + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dirtorch/nets/__main__.py b/dirtorch/nets/__main__.py new file mode 100644 index 0000000..0787cdc --- /dev/null +++ b/dirtorch/nets/__main__.py @@ -0,0 +1,5 @@ +from . import list_archs + +# python -m nets +print("Listing available architectures:") +print("\t" + "\n\t".join(list_archs())) diff --git a/dirtorch/nets/backbones/__init__.py b/dirtorch/nets/backbones/__init__.py new file mode 100644 index 0000000..4f93b12 --- /dev/null +++ b/dirtorch/nets/backbones/__init__.py @@ -0,0 +1,24 @@ +from collections import OrderedDict + + +def load_pretrained_weights(net, state_dict): + """ Load the pretrained weights. + If layers are missing or of wrong shape, will not load them. 
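+    Any leading 'module.' prefix in the state_dict keys (e.g. from a
+    DataParallel checkpoint) is stripped before loading.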
+ """ + new_dict = OrderedDict() + for k,v in list(state_dict.items()): + if k.startswith('module.'): k = k.replace('module.', '') + new_dict[k]=v + + # Add missing weights from the network itself + d = net.state_dict() + for k,v in list(d.items()): + if k not in new_dict: + if not k.endswith('num_batches_tracked'): + print("Loading weights for %s: Missing layer %s" % (type(net).__name__,k)) + new_dict[k] = v + elif v.shape != new_dict[k].shape: + print("Loading weights for %s: Bad shape for layer %s, skipping" % (type(net).__name__,k)) + new_dict[k] = v + + net.load_state_dict(new_dict) diff --git a/dirtorch/nets/backbones/inceptionresnetv2.py b/dirtorch/nets/backbones/inceptionresnetv2.py new file mode 100644 index 0000000..2420cec --- /dev/null +++ b/dirtorch/nets/backbones/inceptionresnetv2.py @@ -0,0 +1,354 @@ +from __future__ import print_function, division, absolute_import +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import os +import sys +from torch.autograd import Variable +import torch.nn.functional as F + + +class BasicConv2d(nn.Module): + + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_planes, out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, bias=False) # verify bias false + self.bn = nn.BatchNorm2d(out_planes, + eps=0.001, # value found in tensorflow + momentum=0.1, # default pytorch value + affine=True) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_5b(nn.Module): + + def __init__(self): + super(Mixed_5b, self).__init__() + + self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(192, 48, kernel_size=1, stride=1), + BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(192, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(192, 64, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block35(nn.Module): + + def __init__(self, scale=1.0): + super(Block35, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), + BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) + ) + + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_6a(nn.Module): + + def __init__(self): + super(Mixed_6a, self).__init__() + + self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 256, kernel_size=1, stride=1), + 
BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Block17(nn.Module): + + def __init__(self, scale=1.0): + super(Block17, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 128, kernel_size=1, stride=1), + BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)), + BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0)) + ) + + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_7a(nn.Module): + + def __init__(self): + super(Mixed_7a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), + BasicConv2d(288, 320, kernel_size=3, stride=2) + ) + + self.branch3 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block8(nn.Module): + + def __init__(self, scale=1.0, noReLU=False): + super(Block8, self).__init__() + + self.scale = scale + self.noReLU = noReLU + + self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(2080, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)), + BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)) + ) + + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + if not self.noReLU: + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + if not self.noReLU: + out = self.relu(out) + return out + + +class InceptionResNetV2(nn.Module): + + def __init__(self, fc_out): + nn.Module.__init__(self) + self.fc_out = fc_out + self.fc_in = 1536 + # Special attributes + self.input_space = None + self.rgb_means = [0.5, 0.5, 0.5] + self.rgb_stds = [0.5, 0.5, 0.5] + self.input_size = (3, 229, 229) + # Modules + self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.maxpool_3a = nn.MaxPool2d(3, stride=2) + self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.maxpool_5a = nn.MaxPool2d(3, stride=2) + self.mixed_5b = Mixed_5b() + self.repeat = nn.Sequential( + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), 
+ Block35(scale=0.17), + Block35(scale=0.17) + ) + self.mixed_6a = Mixed_6a() + self.repeat_1 = nn.Sequential( + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10) + ) + self.mixed_7a = Mixed_7a() + self.repeat_2 = nn.Sequential( + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20) + ) + self.block8 = Block8(noReLU=True) + self.conv2d_7b = BasicConv2d(2080, self.fc_in, kernel_size=1, stride=1) + self.last_linear = None + if self.fc_out > 0: + self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False) + self.last_linear = nn.Linear(self.fc_in, self.fc_out) + self.fc_name = 'last_linear' + + def features(self, input): + x = self.conv2d_1a(input) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + x = self.mixed_6a(x) + x = self.repeat_1(x) + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) + return x + + def logits(self, features): + x = self.avgpool_1a(features) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + if self.fc_out > 0: + x = self.logits(x) + return x + + def load_pretrained_weights(self, pretrain_code): + if pretrain_code == 'imagenet': + url = 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth' + else: + raise NameError("unknown pretraining code '%s'" % pretrain_code) + + print("Loading ImageNet pretrained weights for %s" % pretrain_code) + + model_dir='dirtorch/data/models/classification/' + import os, stat # give group permission + try: os.makedirs(model_dir) + except OSError: pass + + import torch.utils.model_zoo as model_zoo + state_dict = model_zoo.load_url(url, model_dir=model_dir) + + from . 
import load_pretrained_weights + load_pretrained_weights(self, state_dict) + + +''' +TEST +Run this code with: +``` +cd $HOME/pretrained-models.pytorch +python -m pretrainedmodels.inceptionresnetv2 +``` +''' +if __name__ == '__main__': + + assert inceptionresnetv2(num_classes=10, pretrained=None) + print('success') + assert inceptionresnetv2(num_classes=1000, pretrained='imagenet') + print('success') + assert inceptionresnetv2(num_classes=1001, pretrained='imagenet+background') + print('success') + + # fail + assert inceptionresnetv2(num_classes=1001, pretrained='imagenet') diff --git a/dirtorch/nets/backbones/resnet.py b/dirtorch/nets/backbones/resnet.py new file mode 100644 index 0000000..90e4f23 --- /dev/null +++ b/dirtorch/nets/backbones/resnet.py @@ -0,0 +1,227 @@ +import torch.nn as nn +import torch +import math +import numpy as np +from torch.autograd import Variable +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + +class Bottleneck(nn.Module): + ''' Standard bottleneck block + input = inplanes * H * W + middle = planes * H/stride * W/stride + output = 4*planes * H/stride * W/stride + ''' + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + + + +def reset_weights(net): + for m in net.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class ResNet(nn.Module): + """ A standard ResNet. 
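+
+    A minimal usage sketch (the 224x224 input and the resulting shape are only
+    illustrative):
+
+        net = resnet101(out_dim=0)          # fc_out == 0 keeps just the conv trunk
+        net.load_pretrained_weights('imagenet')
+        feat = net(torch.randn(1, 3, 224, 224))   # -> (1, 2048, 7, 7) feature map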
+ """ + def __init__(self, block, layers, fc_out, model_name, self_similarity_radius=None, self_similarity_version=2): + nn.Module.__init__(self) + self.model_name = model_name + + # default values for a network pre-trained on imagenet + self.rgb_means = [0.485, 0.456, 0.406] + self.rgb_stds = [0.229, 0.224, 0.225] + self.input_size = (3, 224, 224) + + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], self_similarity_radius=self_similarity_radius, self_similarity_version=self_similarity_version) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, self_similarity_radius=self_similarity_radius, self_similarity_version=self_similarity_version) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, self_similarity_radius=self_similarity_radius, self_similarity_version=self_similarity_version) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, self_similarity_radius=self_similarity_radius, self_similarity_version=self_similarity_version) + + reset_weights(self) + + self.fc = None + self.fc_out = fc_out + if self.fc_out > 0: + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(512 * block.expansion, fc_out) + self.fc_name = 'fc' + + def _make_layer(self, block, planes, blocks, stride=1, self_similarity_radius=None, self_similarity_version=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes=planes, stride=stride, downsample=downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + if self_similarity_radius: + if self_similarity_version == 1: + from . self_sim import SelfSimilarity1 + layers.append(SelfSimilarity1(self_similarity_radius, self.inplanes)) + else: + from . 
self_sim import SelfSimilarity2 + layers.append(SelfSimilarity2(self_similarity_radius, self.inplanes)) + return nn.Sequential(*layers) + + def forward(self, x, out_layer=0): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + if out_layer==-1: + return x, self.layer4(x) + x = self.layer4(x) + + if self.fc_out > 0: + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + def load_pretrained_weights(self, pretrain_code): + if pretrain_code == 'imagenet': + model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + } + else: + raise NameError("unknown pretraining code '%s'" % pretrain_code) + + print("Loading ImageNet pretrained weights for %s" % pretrain_code) + assert self.model_name in model_urls, "Unknown model '%s'" % self.model_name + + model_dir='dirtorch/data/models/classification/' + import os, stat # give group permission + try: os.makedirs(model_dir) + except OSError: pass + + import torch.utils.model_zoo as model_zoo + state_dict = model_zoo.load_url(model_urls[self.model_name], model_dir=model_dir) + + from . import load_pretrained_weights + load_pretrained_weights(self, state_dict) + + + + + +def resnet18(out_dim=2048): + """Constructs a ResNet-18 model. + """ + net = ResNet(BasicBlock, [2, 2, 2, 2], out_dim, 'resnet18') + return net + +def resnet50(out_dim=2048): + """Constructs a ResNet-50 model. + """ + net = ResNet(Bottleneck, [3, 4, 6, 3], out_dim, 'resnet50') + return net + +def resnet101(out_dim=2048): + """Constructs a ResNet-101 model. + """ + net = ResNet(Bottleneck, [3, 4, 23, 3], out_dim, 'resnet101') + return net + +def resnet152(out_dim=2048): + """Constructs a ResNet-152 model. 
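+
+    For example, resnet152(out_dim=0) returns the convolutional trunk only,
+    with no average pooling or FC head.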
+ """ + net = ResNet(Bottleneck, [3, 8, 36, 3], out_dim, 'resnet152') + return net diff --git a/dirtorch/nets/backbones/resnext101_features.py b/dirtorch/nets/backbones/resnext101_features.py new file mode 100644 index 0000000..0ee5ee9 --- /dev/null +++ b/dirtorch/nets/backbones/resnext101_features.py @@ -0,0 +1,1338 @@ +from __future__ import print_function, division, absolute_import +import torch +import torch.nn as nn +from torch.autograd import Variable +from functools import reduce + +class LambdaBase(nn.Sequential): + def __init__(self, fn, *args): + super(LambdaBase, self).__init__(*args) + self.lambda_func = fn + + def forward_prepare(self, input): + output = [] + for module in self._modules.values(): + output.append(module(input)) + return output if output else input + +class Lambda(LambdaBase): + def forward(self, input): + return self.lambda_func(self.forward_prepare(input)) + +class LambdaMap(LambdaBase): + def forward(self, input): + return list(map(self.lambda_func,self.forward_prepare(input))) + +class LambdaReduce(LambdaBase): + def forward(self, input): + return reduce(self.lambda_func,self.forward_prepare(input)) + +resnext101_32x4d_features = nn.Sequential( # Sequential, + nn.Conv2d(3,64,(7, 7),(2, 2),(3, 3),1,1,bias=False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d((3, 3),(2, 2),(1, 1)), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(64,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + ), + nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + ), + nn.Sequential( # Sequential, + nn.Conv2d(64,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + ), + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + ), + nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + ), + nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256,256,(3, 3),(2, 2),(1, 1),1,32,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + ), + nn.Sequential( # Sequential, + nn.Conv2d(256,512,(1, 1),(2, 2),(0, 0),1,1,bias=False), + 
nn.BatchNorm2d(512), + ), + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(2, 2),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + nn.Sequential( # Sequential, + nn.Conv2d(512,1024,(1, 1),(2, 2),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + 
nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + 
nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda 
x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024,1024,(3, 3),(2, 2),(1, 1),1,32,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(2048), + ), + nn.Sequential( # 
Sequential, + nn.Conv2d(1024,2048,(1, 1),(2, 2),(0, 0),1,1,bias=False), + nn.BatchNorm2d(2048), + ), + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + ) +) + + + +resnext101_64x4d_features = nn.Sequential(#Sequential, + nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias = False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d((3, 3), (2, 2), (1, 1)), + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + ), + nn.Sequential(#Sequential, + nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + ), + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512, 512, 
(1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + nn.Sequential(#Sequential, + nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + nn.Sequential(#Sequential, + nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + 
Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 
0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), 
+ Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), 
(0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + nn.Conv2d(2048, 2048, (3, 3), (2, 2), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + ), + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + ), + nn.Sequential(#Sequential, + nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + ), + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + nn.Conv2d(2048, 2048, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + ), + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + nn.Conv2d(2048, 2048, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + ), + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + ) +) diff --git a/dirtorch/nets/backbones/senet.py b/dirtorch/nets/backbones/senet.py new file mode 100644 index 0000000..61bd924 --- /dev/null +++ b/dirtorch/nets/backbones/senet.py @@ -0,0 +1,384 @@ +""" +ResNet code gently borrowed from +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +""" +from __future__ import print_function, division, absolute_import +from collections import OrderedDict +import math + +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import os +import sys +from torch.autograd import Variable +import torch.nn.functional as F + + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, + padding=0) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, + padding=0) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class Bottleneck(nn.Module): + """ + Base class for bottlenecks that implements `forward()` method. 
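+    Subclasses (SEBottleneck, SEResNetBottleneck and SEResNeXtBottleneck below)
+    only define the conv/bn layers, the SEModule and the optional downsample
+    branch; this shared forward() combines them as out = se_module(out) + residual.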
+ """ + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_module(out) + residual + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, + stride=stride, padding=1, groups=groups, + bias=False) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, + stride=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, + groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Squeeze-and-Excitation module. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None, base_width=4): + super(SEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, + stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SENet(nn.Module): + + def __init__(self, block, layers, groups, reduction, fc_out, model_name, dropout_p=0.2, + inplanes=128, input_3x3=True, downsample_kernel_size=3, + downsample_padding=1): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. 
+ - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. + - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. + - For all models: 1000 + """ + super(SENet, self).__init__() + self.model_name = model_name + self.inplanes = inplanes + # default values for a network pre-trained on imagenet + self.rgb_means = [0.485, 0.456, 0.406] + self.rgb_stds = [0.229, 0.224, 0.225] + self.input_size = (3, 224, 224) + + if input_3x3: + layer0_modules = [ + ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, + bias=False)), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, + bias=False)), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, + bias=False)), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, + padding=3, bias=False)), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + # To preserve compatibility with Caffe weights `ceil_mode=True` + # is used instead of `padding=1`. 
+ layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, + ceil_mode=True))) + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.last_linear = None + self.fc_out = fc_out + if self.fc_out > 0: + self.avg_pool = nn.AvgPool2d(7, stride=1) + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.last_linear = nn.Linear(512 * block.expansion, fc_out) + self.fc_name = 'last_linear' + + def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, + downsample_kernel_size=1, downsample_padding=0): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=downsample_kernel_size, stride=stride, + padding=downsample_padding, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, groups, reduction, stride, + downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def features(self, x): + x = self.layer0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def logits(self, x): + x = self.avg_pool(x) + if self.dropout is not None: + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, x): + x = self.features(x) + if self.fc_out > 0: + x = self.logits(x) + return x + + def load_pretrained_weights(self, pretrain_code): + if pretrain_code == 'imagenet': + model_urls = { + 'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', + 'se_resnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', + 'se_resnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', + 'se_resnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', + 'se_resnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', + 'se_resnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', + } + else: + raise NameError("unknown pretraining code '%s'" % pretrain_code) + + print("Loading ImageNet pretrained weights for %s" % pretrain_code) + assert self.model_name in model_urls, "Unknown model '%s'" % self.model_name + + model_dir='dirtorch/data/models/classification/' + import os, stat # give group permission + try: os.makedirs(model_dir) + except OSError: pass + + import torch.utils.model_zoo as model_zoo + state_dict = model_zoo.load_url(model_urls[self.model_name],
model_dir=model_dir) + + from . import load_pretrained_weights + load_pretrained_weights(self, state_dict) + + + +def senet154(out_dim=2048): + model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, + dropout_p=0.2, fc_out=out_dim, model_name='senet154') + return model + + +def se_resnet50(out_dim=2048): + model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + fc_out=out_dim, model_name='se_resnet50') + return model + + +def se_resnet101(out_dim=2048): + model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + fc_out=out_dim, model_name='se_resnet101') + return model + + +def se_resnet152(out_dim=2048): + model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + fc_out=out_dim, model_name='se_resnet152') + return model + + +def se_resnext50_32x4d(out_dim=2048): + model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + fc_out=out_dim, model_name='se_resnext50_32x4d') + return model + + +def se_resnext101_32x4d(out_dim=2048): + model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + fc_out=out_dim, model_name='se_resnext101_32x4d') + return model diff --git a/dirtorch/nets/layers/pooling.py b/dirtorch/nets/layers/pooling.py new file mode 100644 index 0000000..926c0ac --- /dev/null +++ b/dirtorch/nets/layers/pooling.py @@ -0,0 +1,56 @@ +import pdb +import numpy as np +import torch +from torch.autograd import Variable + +import torch.nn as nn +from torch.nn.modules import Module +from torch.nn.parameter import Parameter +import torch.nn.functional as F +import math + +class GeneralizedMeanPooling(Module): + r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. + + The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` + + - At p = infinity, one gets Max Pooling + - At p = 1, one gets Average Pooling + + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + + """ + + def __init__(self, norm, output_size=1, eps=1e-6): + super(GeneralizedMeanPooling, self).__init__() + assert norm > 0 + self.p = float(norm) + self.output_size = output_size + self.eps = eps + + def forward(self, x): + x = x.clamp(min=self.eps).pow(self.p) + return F.adaptive_avg_pool2d(x, self.output_size).pow(1. 
/ self.p) + + def __repr__(self): + return self.__class__.__name__ + '(' \ + + str(self.p) + ', ' \ + + 'output_size=' + str(self.output_size) + ')' + + + +class GeneralizedMeanPoolingP(GeneralizedMeanPooling): + """ Same, but norm is trainable + """ + def __init__(self, norm=3, output_size=1, eps=1e-6): + super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) + self.p = Parameter(torch.ones(1) * norm) + + diff --git a/dirtorch/nets/rmac_inceptionresnetv2.py b/dirtorch/nets/rmac_inceptionresnetv2.py new file mode 100644 index 0000000..da830de --- /dev/null +++ b/dirtorch/nets/rmac_inceptionresnetv2.py @@ -0,0 +1,73 @@ +import pdb +from .backbones.inceptionresnetv2 import * +from .layers.pooling import GeneralizedMeanPooling, GeneralizedMeanPoolingP + + +def l2_normalize(x, axis=-1): + x = F.normalize(x, p=2, dim=axis) + return x + + +class InceptionResNetV2_RMAC(InceptionResNetV2): + """ ResNet for RMAC (without ROI pooling) + """ + def __init__(self, out_dim=2048, norm_features=False, + pooling='gem', gemp=3, center_bias=0, + dropout_p=None, without_fc=False): + InceptionResNetV2.__init__(self, 0) + self.norm_features = norm_features + self.without_fc = without_fc + self.pooling = pooling + self.center_bias = center_bias + + if pooling == 'max': + self.adpool = nn.AdaptiveMaxPool2d(output_size=1) + elif pooling == 'avg': + self.adpool = nn.AdaptiveAvgPool2d(output_size=1) + elif pooling.startswith('gem'): + self.adpool = GeneralizedMeanPoolingP(norm=gemp) + + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.last_linear = nn.Linear(1536, out_dim) + self.fc_name = 'last_linear' + self.feat_dim = out_dim + self.detach = False + + def forward(self, x): + x = InceptionResNetV2.forward(self, x) + + if self.dropout is not None: + x = self.dropout(x) + + if self.detach: + # stop the back-propagation here, if needed + x = Variable(x.detach()) + x = self.id(x) # fake transformation + + if self.center_bias > 0: + b = self.center_bias + bias = 1 + torch.FloatTensor([[[[0,0,0,0],[0,b,b,0],[0,b,b,0],[0,0,0,0]]]]).to(x.device) + bias = torch.nn.functional.interpolate(bias, size=x.shape[-2:], mode='bilinear', align_corners=True) + x = x*bias + + # global pooling + x = self.adpool(x) + + if self.norm_features: + x = l2_normalize(x, axis=1) + + x.squeeze_() + if not self.without_fc: + x = self.last_linear(x) + + x = l2_normalize(x, axis=-1) + return x + + + + +def inceptionresnetv2_rmac(backbone=InceptionResNetV2_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(**kwargs) + + diff --git a/dirtorch/nets/rmac_resnet.py b/dirtorch/nets/rmac_resnet.py new file mode 100644 index 0000000..e148e34 --- /dev/null +++ b/dirtorch/nets/rmac_resnet.py @@ -0,0 +1,117 @@ +import pdb +import torch +from .backbones.resnet import * +from .layers.pooling import GeneralizedMeanPooling, GeneralizedMeanPoolingP + + +def l2_normalize(x, axis=-1): + x = F.normalize(x, p=2, dim=axis) + return x + + +class ResNet_RMAC(ResNet): + """ ResNet for RMAC (without ROI pooling) + """ + def __init__(self, block, layers, model_name, out_dim=2048, norm_features=False, + pooling='gem', gemp=3, center_bias=0, + dropout_p=None, without_fc=False, **kwargs): + ResNet.__init__(self, block, layers, 0, model_name, **kwargs) + self.norm_features = norm_features + self.without_fc = without_fc + self.pooling = pooling + self.center_bias = center_bias + + if pooling == 'max': + self.adpool = nn.AdaptiveMaxPool2d(output_size=1) + elif pooling == 'avg': + self.adpool = 
nn.AdaptiveAvgPool2d(output_size=1) + elif pooling.startswith('gem'): + self.adpool = GeneralizedMeanPoolingP(norm=gemp) + else: + raise ValueError(pooling) + + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.fc = nn.Linear(512 * block.expansion, out_dim) + self.fc_name = 'fc' + self.feat_dim = out_dim + self.detach = False + + def forward(self, x): + bs, _, H, W = x.shape + + x = ResNet.forward(self, x) + + if self.dropout is not None: + x = self.dropout(x) + + if self.detach: + # stop the back-propagation here, if needed + x = Variable(x.detach()) + x = self.id(x) # fake transformation + + if self.center_bias > 0: + b = self.center_bias + bias = 1 + torch.FloatTensor([[[[0,0,0,0],[0,b,b,0],[0,b,b,0],[0,0,0,0]]]]).to(x.device) + bias = torch.nn.functional.interpolate(bias, size=x.shape[-2:], mode='bilinear', align_corners=True) + x = x*bias + + # global pooling + x = self.adpool(x) + + if self.norm_features: + x = l2_normalize(x, axis=1) + + x.squeeze_() + if not self.without_fc: + x = self.fc(x) + + x = l2_normalize(x, axis=-1) + return x + + + + +def resnet18_rmac(backbone=ResNet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(BasicBlock, [2, 2, 2, 2], 'resnet18', **kwargs) + +def resnet50_rmac(backbone=ResNet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 4, 6, 3], 'resnet50', **kwargs) + +def resnet101_rmac(backbone=ResNet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 4, 23, 3], 'resnet101', **kwargs) + +def resnet152_rmac(backbone=ResNet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 8, 36, 3], 'resnet152', **kwargs) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dirtorch/nets/rmac_resnet_fpn.py b/dirtorch/nets/rmac_resnet_fpn.py new file mode 100644 index 0000000..aea143f --- /dev/null +++ b/dirtorch/nets/rmac_resnet_fpn.py @@ -0,0 +1,142 @@ +import pdb +from .backbones.resnet import * +from .layers.pooling import GeneralizedMeanPooling, GeneralizedMeanPoolingP + + +def l2_normalize(x, axis=-1): + x = F.normalize(x, p=2, dim=axis) + return x + + +class ResNet_RMAC_FPN(ResNet): + """ ResNet for RMAC (without ROI pooling) + """ + def __init__(self, block, layers, model_name, out_dim=None, norm_features=False, + pooling='gem', gemp=3, center_bias=0, mode=1, + dropout_p=None, without_fc=False, **kwargs): + ResNet.__init__(self, block, layers, 0, model_name, **kwargs) + self.norm_features = norm_features + self.without_fc = without_fc + self.pooling = pooling + self.center_bias = center_bias + self.mode = mode + + dim1 = 256 * block.expansion + dim2 = 512 * block.expansion + if out_dim is None: out_dim = dim1 + dim2 + #FPN + if self.mode == 1: + self.conv1x5 = nn.Conv2d(dim2, dim1, kernel_size=1, stride=1, bias=False) + self.conv3c4 = nn.Conv2d(dim1, dim1, kernel_size=3, stride=1, padding=1, bias=False) + self.relu = nn.ReLU(inplace=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + + if pooling == 'max': + self.adpool = nn.AdaptiveMaxPool2d(output_size=1) + elif pooling == 'avg': + self.adpool = nn.AdaptiveAvgPool2d(output_size=1) + elif pooling == 'gem': + self.adpoolx5 = GeneralizedMeanPoolingP(norm=gemp) + self.adpoolc4 = GeneralizedMeanPoolingP(norm=gemp) + + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.fc = nn.Linear(768 * block.expansion, out_dim) + self.fc_name = 'fc' + self.feat_dim = out_dim + self.detach = False + + def forward(self, x): + x4, x5 = ResNet.forward(self, x, -1) + + # FPN + if self.mode == 1: + c5 = F.interpolate(x5, size=x4.shape[-2:], mode='nearest') + + c5 = self.conv1x5(c5) + c5 = self.relu(c5) + x4 = x4 + c5 + x4 = self.conv3c4(x4) + x4 = self.relu(x4) + + if self.dropout is not None: + x5 = self.dropout(x5) + x4 = self.dropout(x4) + + if self.detach: + # stop the back-propagation here, if needed + x5 = Variable(x5.detach()) + x5 = self.id(x5) # fake transformation + x4 = Variable(x4.detach()) + x4 = self.id(x4) # fake transformation + + # global pooling + x5 = self.adpoolx5(x5) + x4 = self.adpoolc4(x4) + + x = torch.cat((x4, x5), 1) + + if self.norm_features: + x = l2_normalize(x, axis=1) + + x.squeeze_() + if not self.without_fc: + x = self.fc(x) + + x = l2_normalize(x, axis=-1) + return x + + + + +def resnet18_fpn_rmac(backbone=ResNet_RMAC_FPN, **kwargs): + kwargs.pop('scales', None) + return backbone(BasicBlock, [2, 2, 2, 2], 'resnet18', **kwargs) + +def resnet50_fpn_rmac(backbone=ResNet_RMAC_FPN, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 4, 6, 3], 'resnet50', **kwargs) + +def resnet101_fpn_rmac(backbone=ResNet_RMAC_FPN, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 4, 23, 3], 'resnet101', **kwargs) + +def resnet101_fpn0_rmac(backbone=ResNet_RMAC_FPN, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 4, 23, 3], 'resnet101', mode=0, **kwargs) + +def resnet152_fpn_rmac(backbone=ResNet_RMAC_FPN, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 8, 36, 3], 'resnet152', **kwargs) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dirtorch/nets/rmac_resnet_ms.py b/dirtorch/nets/rmac_resnet_ms.py new file mode 100644 index 0000000..70ec758 --- /dev/null +++ b/dirtorch/nets/rmac_resnet_ms.py @@ -0,0 +1,47 @@ +import pdb +from .rmac_resnet import * + + +class ResNet_RMAC_MultiScale (ResNet_RMAC): + """ ResNet for RMAC (without ROI pooling) + """ + def __init__(self, block, layers, out_dim=2048, scales=[1,0.5], **kwargs): + ResNet_RMAC.__init__(self, block, layers, out_dim, **kwargs) + assert scales[0] == 1 and all(0 0: + b = self.center_bias + bias = 1 + torch.FloatTensor([[[[0,0,0,0],[0,b,b,0],[0,b,b,0],[0,0,0,0]]]]).to(x.device) + bias = torch.nn.functional.interpolate(bias, size=x.shape[-2:], mode='bilinear', align_corners=True) + x = x*bias + + # global pooling + x = self.adpool(x) + + if self.norm_features: + x = l2_normalize(x, axis=1) + + x.squeeze_() + if not self.without_fc: + x = self.fc(x) + + x = l2_normalize(x, axis=-1) + return x + + + + +def resnet18_rmac(backbone=ResNet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(BasicBlock, [2, 2, 2, 2], 'resnet18', **kwargs) + +def resnet50_rmac(backbone=ResNet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 4, 6, 3], 'resnet50', **kwargs) + +def resnet101_rmac(backbone=ResNet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 4, 23, 3], 
'resnet101', **kwargs) + +def resnet152_rmac(backbone=ResNet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(Bottleneck, [3, 8, 36, 3], 'resnet152', **kwargs) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dirtorch/nets/rmac_senet.py b/dirtorch/nets/rmac_senet.py new file mode 100644 index 0000000..9faa8ea --- /dev/null +++ b/dirtorch/nets/rmac_senet.py @@ -0,0 +1,95 @@ +import pdb +from .backbones.senet import * +from .layers.pooling import GeneralizedMeanPooling, GeneralizedMeanPoolingP + + +def l2_normalize(x, axis=-1): + x = F.normalize(x, p=2, dim=axis) + return x + + +class SENet_RMAC(SENet): + """ SENet for RMAC (without ROI pooling) + """ + def __init__(self, block, layers, groups, reduction, model_name, + out_dim=2048, norm_features=False, pooling='gem', gemp=3, center_bias=0, + dropout_p=None, without_fc=False, **kwargs): + SENet.__init__(self, block, layers, groups, reduction, 0, model_name, **kwargs) + self.norm_features = norm_features + self.without_fc = without_fc + self.pooling = pooling + self.center_bias = center_bias + + if pooling == 'max': + self.adpool = nn.AdaptiveMaxPool2d(output_size=1) + elif pooling == 'avg': + self.adpool = nn.AdaptiveAvgPool2d(output_size=1) + elif pooling == 'gem': + self.adpool = GeneralizedMeanPoolingP(norm=gemp) + + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.last_linear = nn.Linear(512 * block.expansion, out_dim) + self.fc_name = 'last_linear' + self.feat_dim = out_dim + self.detach = False + + def forward(self, x): + x = SENet.forward(self, x) + + if self.dropout is not None: + x = self.dropout(x) + + if self.detach: + # stop the back-propagation here, if needed + x = Variable(x.detach()) + x = self.id(x) # fake transformation + + if self.center_bias > 0: + b = self.center_bias + bias = 1 + torch.FloatTensor([[[[0,0,0,0],[0,b,b,0],[0,b,b,0],[0,0,0,0]]]]).to(x.device) + bias = torch.nn.functional.interpolate(bias, size=x.shape[-2:], mode='bilinear', align_corners=True) + x = x*bias + + # global pooling + x = self.adpool(x) + + if self.norm_features: + x = l2_normalize(x, axis=1) + + x.squeeze_() + if not self.without_fc: + x = self.last_linear(x) + + x = l2_normalize(x, axis=-1) + return x + + + +def senet154_rmac(backbone=SENet_RMAC, **kwargs): + kwargs.pop('scales', None) + return backbone(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, model_name='senet154', **kwargs) + +def se_resnet50_rmac(backbone=SENet_RMAC, **kwargs): + kwargs.pop('scales', None) + kwargs = {'inplanes': 64, 'input_3x3': False, 'downsample_kernel_size': 1, 'downsample_padding': 0, **kwargs} + return backbone(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, model_name='se_resnet50', **kwargs) + +def se_resnet101_rmac(backbone=SENet_RMAC, **kwargs): + kwargs.pop('scales', None) + kwargs = {'inplanes': 64, 'input_3x3': False, 'downsample_kernel_size': 1, 'downsample_padding': 0, **kwargs} + return backbone(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, model_name='se_resnet101', **kwargs) + +def se_resnet152_rmac(backbone=SENet_RMAC, **kwargs): + kwargs.pop('scales', None) + kwargs = {'inplanes': 64, 'input_3x3': False, 'downsample_kernel_size': 1, 'downsample_padding': 0, **kwargs} + return backbone(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, model_name='se_resnet152', **kwargs) + +def se_resnext50_32x4d_rmac(backbone=SENet_RMAC, **kwargs): + kwargs.pop('scales', None) + kwargs = {'inplanes': 64, 'input_3x3': False, 'downsample_kernel_size': 1, 
'downsample_padding': 0, **kwargs} + return backbone(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, model_name='se_resnext50_32x4d', **kwargs) + +def se_resnext101_32x4d_rmac(backbone=SENet_RMAC, **kwargs): + kwargs.pop('scales', None) + kwargs = {'inplanes': 64, 'input_3x3': False, 'downsample_kernel_size': 1, 'downsample_padding': 0, **kwargs} + return backbone(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, model_name='se_resnext101_32x4d', **kwargs) diff --git a/dirtorch/test_dir.py b/dirtorch/test_dir.py new file mode 100644 index 0000000..e01ac3f --- /dev/null +++ b/dirtorch/test_dir.py @@ -0,0 +1,338 @@ +import sys +import os; os.umask(7) # group permisions but that's all +import os.path as osp +import pdb + +import json +import tqdm +import numpy as np +import torch +import torch.nn.functional as F + +from dirtorch.utils.convenient import mkdir +from dirtorch.utils import common +from dirtorch.utils.pytorch_loader import get_loader +import dirtorch.nets as nets +import dirtorch.datasets as datasets +import dirtorch.datasets.downloader as dl + +import pickle as pkl +import hashlib + +def hash(x): + m = hashlib.md5() + m.update(str(x).encode('utf-8')) + return m.hexdigest() + +def typename(x): + return type(x).__module__ + +def tonumpy(x): + if typename(x) == torch.__name__: + return x.cpu().numpy() + else: + return x + +def matmul(A, B): + if typename(A) == np.__name__: + B = tonumpy(B) + scores = np.dot(A, B.T) + elif typename(B) == torch.__name__: + scores = torch.matmul(A, B.t()).cpu().numpy() + else: + raise TypeError("matrices must be either numpy or torch type") + return scores + +def expand_descriptors(descs, db=None, alpha=0, k=0): + assert k >= 0 and alpha >= 0, 'k and alpha must be non-negative' + if k == 0: return descs + descs = tonumpy(descs) + n = descs.shape[0] + db_descs = tonumpy(db if db is not None else descs) + + sim = matmul(descs, db_descs) + if db is None: + sim[np.diag_indices(n)] = 0 + + idx = np.argpartition(sim, int(-k), axis=1)[:, int(-k):] + descs_aug = np.zeros_like(descs) + for i in range(n): + new_q = np.vstack([db_descs[j, :] * sim[i,j]**alpha for j in idx[i]]) + new_q = np.vstack([descs[i], new_q]) + new_q = np.mean(new_q, axis=0) + descs_aug[i] = new_q / np.linalg.norm(new_q) + + return descs_aug + +def extract_image_features( dataset, transforms, net, ret_imgs=False, same_size=False, flip=None, + desc="Extract feats...", iscuda=True, threads=8, batch_size=8): + """ Extract image features for a given dataset. + Output is 2-dimensional (B, D) + """ + if not same_size: + batch_size = 1 + old_benchmark = torch.backends.cudnn.benchmark # speed-up cudnn + torch.backends.cudnn.benchmark = False # will speed-up a lot for different image sizes + + loader = get_loader( dataset, trf_chain=transforms, preprocess=net.preprocess, iscuda=iscuda, + output=['img'], batch_size=batch_size, threads=threads, + shuffle=False) # VERY IMPORTANT !!!!!! + + if hasattr(net,'eval'): net.eval() + + tocpu = (lambda x: x.cpu()) if ret_imgs=='cpu' else (lambda x:x) + + img_feats = [] + trf_images = [] + with torch.no_grad(): # important to put it outside! + for inputs in tqdm.tqdm(loader, desc, total=1+(len(dataset)-1)//batch_size): + imgs = inputs[0] + for i in range(len(imgs)): + if flip and flip.pop(0): + imgs[i] = imgs[i].flip(2) # flip this image horizontally! 
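+ # the optional `flip` list above decides per-image horizontal flipping; the + # batch is then moved to GPU (when iscuda is set) and forwarded through the net, + # producing one global descriptor per image that is accumulated in img_feats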
+ imgs = common.variables(inputs[:1], net.iscuda)[0] + desc = net(imgs) + if ret_imgs: trf_images.append( tocpu(imgs.detach()) ) + del imgs + del inputs + if len(desc.shape) == 1: desc.unsqueeze_(0) + img_feats.append( desc.detach() ) + + img_feats = torch.cat(img_feats, dim=0) + if len(img_feats.shape) == 1: img_feats.unsqueeze_(0) # atleast_2d + + if not same_size: + torch.backends.cudnn.benchmark = old_benchmark + + if ret_imgs: + if same_size: trf_images = torch.cat(trf_images, dim=0) + return trf_images, img_feats + return img_feats + + +def pool(x, pooling='mean', gemp=3): + if len(x) == 1: return x[0] + x = torch.stack(x, dim=0) + if pooling == 'mean': + return torch.mean(x, dim=0) + elif pooling == 'gem': + def sympow(x, p, eps=1e-6): + s = torch.sign(x) + return (x*s).clamp(min=eps).pow(p) * s + x = sympow(x,gemp) + x = torch.mean(x, dim=0) + return sympow(x, 1/gemp) + else: + raise ValueError("Bad pooling mode: "+str(pooling)) + + +def eval_model(db, net, trfs, pooling='mean', gemp=3, detailed=False, whiten=None, + aqe=None, adba=None, threads=8, batch_size=16, save_feats=None, + load_feats=None, load_distractors=None, dbg=()): + """ Evaluate a trained model (network) on a given dataset. + The dataset is supposed to contain the evaluation code. + """ + print("\n>> Evaluation...") + query_db = db.get_query_db() + + # extract DB feats + bdescs = [] + qdescs = [] + + if not load_feats: + trfs_list = [trfs] if isinstance(trfs, str) else trfs + + for trfs in trfs_list: + kw = dict(iscuda=net.iscuda, threads=threads, batch_size=batch_size, same_size='Pad' in trfs or 'Crop' in trfs) + bdescs.append( extract_image_features(db, trfs, net, desc="DB", **kw) ) + + # extract query feats + qdescs.append( bdescs[-1] if db is query_db else extract_image_features(query_db, trfs, net, desc="query", **kw) ) + + # pool from multiple transforms (scales) + bdescs = F.normalize(pool(bdescs, pooling, gemp), p=2, dim=1) + qdescs = F.normalize(pool(qdescs, pooling, gemp), p=2, dim=1) + else: + bdescs = np.load(os.path.join(load_feats, 'feats.bdescs.npy')) + qdescs = np.load(os.path.join(load_feats, 'feats.qdescs.npy')) + + if save_feats: + mkdir(save_feats, isfile=True) + np.save(save_feats+'.bdescs', bdescs.cpu().numpy()) + if query_db is not db: + np.save(save_feats+'.qdescs', qdescs.cpu().numpy()) + exit() + + if load_distractors: + ddescs = [ np.load(os.path.join(load_distractors, '%d.bdescs.npy' % i)) for i in tqdm.tqdm(range(0,1000), 'Distractors') ] + bdescs = np.concatenate([tonumpy(bdescs)] + ddescs) + qdescs = tonumpy(qdescs) # so matmul below can work + + if whiten is not None: + bdescs = common.whiten_features(tonumpy(bdescs), net.pca, **whiten) + qdescs = common.whiten_features(tonumpy(qdescs), net.pca, **whiten) + + if adba is not None: + bdescs = expand_descriptors(bdescs, **args.adba) + if aqe is not None: + qdescs = expand_descriptors(qdescs, db=bdescs, **args.aqe) + + scores = matmul(qdescs, bdescs) + + del bdescs + del qdescs + + res = {} + + try: + aps = [db.eval_query_AP(q, s) for q,s in enumerate(tqdm.tqdm(scores,desc='AP'))] + if not isinstance(aps[0], dict): + aps = [float(e) for e in aps] + if detailed: res['APs'] = aps + res['mAP'] = float(np.mean([e for e in aps if e>=0])) # Queries with no relevants have an AP of -1 + else: + modes = aps[0].keys() + for mode in modes: + apst = [float(e[mode]) for e in aps] + if detailed: res['APs'+'-'+mode] = apst + res['mAP'+'-'+mode] = float(np.mean([e for e in apst if e>=0])) # Queries with no relevants have an AP of -1 + + if 'ap' in dbg: + 
pdb.set_trace() + pyplot(globals()) + for query in np.argsort(aps): + subplot_grid(20, 1) + pl.imshow(query_db.get_image(query)) + qlabel = query_db.get_label(query) + pl.xlabel('#%d %s' % (query, qlabel)) + pl_noticks() + ranked = np.argsort(scores[query])[::-1] + gt = db.get_query_groundtruth(query)[ranked] + + for i,idx in enumerate(ranked): + if i+2 > 20: break + subplot_grid(20, i+2) + pl.imshow(db.get_image(idx)) + pl.xlabel('#%d %s %g' % (idx, 'OK' if label==qlabel else 'BAD', scores[query,idx])) + pl_noticks() + pdb.set_trace() + except NotImplementedError: + print(" AP not implemented!") + + try: + tops = [db.eval_query_top(q,s) for q,s in enumerate(tqdm.tqdm(scores,desc='top1'))] + if detailed: res['tops'] = tops + for k in tops[0]: + res['top%d'%k] = float(np.mean([top[k] for top in tops])) + except NotImplementedError: + pass + + return res + + +def load_model( path, iscuda ): + checkpoint = common.load_checkpoint(path, iscuda) + net = nets.create_model(pretrained="", **checkpoint['model_options']) + net = common.switch_model_to_cuda(net, iscuda, checkpoint) + net.load_state_dict(checkpoint['state_dict']) + net.preprocess = checkpoint.get('preprocess', net.preprocess) + if 'pca' in checkpoint: net.pca = checkpoint.get('pca', net.pca) + return net + + +def learn_whiten( dataset, net, trfs='', pooling='mean', threads=8, batch_size=16): + descs = [] + trfs_list = [trfs] if isinstance(trfs, str) else trfs + for trfs in trfs_list: + kw = dict(iscuda=net.iscuda, threads=threads, batch_size=batch_size, same_size='Pad' in trfs or 'Crop' in trfs) + descs.append( extract_image_features(dataset, trfs, net, desc="PCA", **kw) ) + # pool from multiple transforms (scales) + descs = F.normalize(pool(descs, pooling), p=2, dim=1) + # learn pca with whiten + pca = common.learn_pca(descs.cpu().numpy(), whiten=True) + return pca + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Evaluate a model') + + parser.add_argument('--dataset', '-d', type=str, required=True, help='Command to load dataset') + parser.add_argument('--checkpoint', type=str, required=True, help='path to weights') + + parser.add_argument('--trfs', type=str, required=False, default='', nargs='+', help='test transforms (can be several)') + parser.add_argument('--pooling', type=str, default="gem", help='pooling scheme if several trf chains') + parser.add_argument('--gemp', type=int, default=3, help='GeM pooling power') + parser.add_argument('--center-bias', type=float, default=0, help='enforce some center bias') + + parser.add_argument('--out-json', type=str, default="", help='path to output json') + parser.add_argument('--detailed', action='store_true', help='return detailed evaluation') + parser.add_argument('--save-feats', type=str, default="", help='path to output features') + parser.add_argument('--load-distractors', type=str, default="", help='path to load distractors from') + parser.add_argument('--load-feats', type=str, default="", help='path to load features from') + + parser.add_argument('--threads', type=int, default=8, help='number of thread workers') + parser.add_argument('--gpu', type=int, default=0, nargs='+', help='GPU ids') + parser.add_argument('--dbg', default=(), nargs='*', help='debugging options') + # post-processing + parser.add_argument('--whiten', type=str, default='Landmarks_clean', help='applies whitening') + + parser.add_argument('--aqe', type=int, nargs='+', help='alpha-query expansion paramenters') + parser.add_argument('--adba', type=int, nargs='+', 
help='alpha-database augmentation paramenters') + + parser.add_argument('--whitenp', type=float, default=0.25, help='whitening power, default is 0.5 (i.e., the sqrt)') + parser.add_argument('--whitenv', type=int, default=None, help='number of components, default is None (i.e. all components)') + parser.add_argument('--whitenm', type=float, default=1.0, help='whitening multiplier, default is 1.0 (i.e. no multiplication)') + + args = parser.parse_args() + args.iscuda = common.torch_set_gpu(args.gpu) + if args.aqe is not None: args.aqe = {'k': args.aqe[0], 'alpha': args.aqe[1]} + if args.adba is not None: args.adba = {'k': args.adba[0], 'alpha': args.adba[1]} + + dl.download_dataset(args.dataset) + + dataset = datasets.create(args.dataset) + print("Test dataset:", dataset) + + net = load_model(args.checkpoint, args.iscuda) + + if args.center_bias: + assert hasattr(net,'center_bias') + net.center_bias = args.center_bias + if hasattr(net, 'module') and hasattr(net.module,'center_bias'): + net.module.center_bias = args.center_bias + + if args.whiten and not hasattr(net, 'pca'): + # Learn PCA if necessary + if os.path.exists(args.whiten): + with open(args.whiten, 'rb') as f: + net.pca = pkl.load(f) + else: + pca_path = '_'.join([args.checkpoint, args.whiten, args.pooling, hash(args.trfs), 'pca.pkl']) + db = datasets.create(args.whiten) + print('Dataset for learning the PCA with whitening:', db) + net.pca = learn_whiten(db, net, pooling=args.pooling, trfs=args.trfs, threads=args.threads) + with open(pca_path, 'wb') as f: + pkl.dump(net.pca, f) + + args.whiten = {'whitenp': args.whitenp, 'whitenv': args.whitenv, 'whitenm': args.whitenm} + + # Evaluate + res = eval_model(dataset, net, args.trfs, pooling=args.pooling, gemp=args.gemp, detailed=args.detailed, + threads=args.threads, dbg=args.dbg, whiten=args.whiten, aqe=args.aqe, adba=args.adba, + save_feats=args.save_feats, load_feats=args.load_feats, load_distractors=args.load_distractors) + print(' * '+'\n * '.join(['%s = %g'%p for p in res.items()])) + + if args.out_json: + # write to file + try: + data = json.load(open(args.out_json)) + except IOError: + data = {} + data[args.dataset] = res + mkdir(args.out_json) + open(args.out_json,'w').write(json.dumps(data, indent=1)) + print("saved to "+args.out_json) + + + diff --git a/dirtorch/utils/common.py b/dirtorch/utils/common.py new file mode 100644 index 0000000..3d300f6 --- /dev/null +++ b/dirtorch/utils/common.py @@ -0,0 +1,226 @@ +import os +import sys +import pdb +import shutil +from collections import OrderedDict +import numpy as np +import sklearn.decomposition + +try: + import torch + import torch.nn as nn +except ImportError: + pass + + +def torch_set_gpu(gpus, seed=None, randomize=True): + if type(gpus) is int: + gpus = [gpus] + + assert gpus, 'error: empty gpu list, use --gpu N N ...' + + cuda = all(gpu>=0 for gpu in gpus) + + if cuda: + if any(gpu >= 1000 for gpu in gpus): + visible_gpus = [int(gpu) for gpu in os.environ['CUDA_VISIBLE_DEVICES'].split(',')] + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(visible_gpus[gpu-1000]) for gpu in gpus]) + else: + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus]) + assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % ( + os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES']) + torch.backends.cudnn.benchmark = True # speed-up cudnn + torch.backends.cudnn.fastest = True # even more speed-up? 
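+ # cudnn benchmarking auto-tunes convolution algorithms for the observed input + # sizes; extract_image_features() in test_dir.py temporarily disables it when + # dataset images have varying sizes, where re-tuning for every new size is slow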
+ print( 'Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'] ) + else: + print( 'Launching on >> CPU <<' ) + + torch_set_seed(seed, cuda, randomize=randomize) + return cuda + + +def torch_set_seed(seed, cuda, randomize=True): + if seed: + # this makes it 3x SLOWER but deterministic + torch.backends.cudnn.enabled = False + + if randomize and not seed: + import time + seed = int(np.uint32(hash(time.time()))) + + if seed: + np.random.seed(seed) + torch.manual_seed(seed) + if cuda: + torch.cuda.manual_seed(seed) + + +def save_checkpoint(state, is_best, filename): + try: + dirs = os.path.split(filename)[0] + if not os.path.isdir(dirs): os.makedirs(dirs) + torch.save(state, filename) + if is_best: + filenamebest = filename+'.best' + shutil.copyfile(filename, filenamebest) + filename = filenamebest + print( "saving to "+filename ) + except: + print( "Error: Could not save checkpoint at %s, skipping" % filename ) + + +def load_checkpoint(filename, iscuda=False): + if not filename: return None + assert os.path.isfile(filename), "=> no checkpoint found at '%s'" % filename + checkpoint = torch.load(filename, map_location=lambda storage, loc: storage) + print("=> loading checkpoint '%s'" % filename, end='') + for key in ['epoch', 'iter', 'current_iter']: + if key in checkpoint: + print(" (%s %d)" % (key, checkpoint[key]), end='') + print() + + new_dict = OrderedDict() + for k,v in list(checkpoint['state_dict'].items()): + if k.startswith('module.'): + k = k[7:] + new_dict[k] = v + checkpoint['state_dict'] = new_dict + + if iscuda and 'optimizer' in checkpoint: + try: + for state in checkpoint['optimizer']['state'].values(): + for k, v in state.items(): + if iscuda and torch.is_tensor(v): + state[k] = v.cuda() + except RuntimeError as e: + print("RuntimeError:",e,"(machine %s, GPU %s)"%( + os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES']), + file=sys.stderr) + sys.exit(1) # error + + return checkpoint + + +def switch_model_to_cuda(model, iscuda=True, checkpoint=None): + if iscuda: + if checkpoint: + checkpoint['state_dict'] = {'module.'+k:v for k,v in checkpoint['state_dict'].items()} + try: + model = torch.nn.DataParallel(model) + + # copy attributes automatically + for var in dir(model.module): + if var.startswith('_'): continue + val = getattr(model.module,var) + if isinstance(val, (bool, int, float, str, dict)) or \ + (callable(val) and var.startswith('get_')): + setattr(model, var, val) + + model.cuda() + model.isasync = True + except RuntimeError as e: + print("RuntimeError:",e,"(machine %s, GPU %s)"%( + os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES']), + file=sys.stderr) + sys.exit(1) # error + + model.iscuda = iscuda + return model + + +def model_size(model): + ''' Computes the number of parameters of the model + ''' + size = 0 + for weights in model.state_dict().values(): + size += np.prod(weights.shape) + return size + + +def freeze_batch_norm(model, freeze=True, only_running=False): + model.freeze_bn = bool(freeze) + if not freeze: return + + for m in model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() # Eval mode freezes the running mean and std + for param in m.named_parameters(): + if only_running: + param[1].requires_grad = True # Weight and bias can be updated + else: + param[1].requires_grad = False # Freeze the weight and bias + + +def variables(inputs, iscuda, not_on_gpu=[]): + ''' convert several Tensors to cuda.Variables + Tensor whose index are in not_on_gpu stays on cpu. 
+ ''' + inputs_var = [] + + for i,x in enumerate(inputs): + if i not in not_on_gpu and not isinstance(x, (tuple,list)): + if iscuda: x = x.cuda(non_blocking=True) + x = torch.autograd.Variable(x) + inputs_var.append(x) + + return inputs_var + +def learn_pca(X, n_components=None, whiten=False, use_sklearn=True): + ''' Learn Principal Component Analysis + + input: + X: input matrix with size samples x features + n_components: number of components to keep + whiten: applies feature whitening + + output: + PCA: weights and means of the PCA learned + ''' + if use_sklearn: + pca = sklearn.decomposition.PCA(n_components=n_components, svd_solver='full', whiten=whiten) + pca.fit(X) + else: + fudge=1E-8 + means = np.mean(X, axis=0) + X = X - means + + # get the covariance matrix + Xcov = np.dot(X.T,X) + + # eigenvalue decomposition of the covariance matrix + d, V = np.linalg.eigh(Xcov) + d[d<0] = fudge + + # a fudge factor can be used so that eigenvectors associated with + # small eigenvalues do not get overamplified. + D = np.diag(1. / np.sqrt(d+fudge)) + + # whitening matrix + W = np.dot(np.dot(V, D), V.T) + + # multiply by the whitening matrix + X_white = np.dot(X, W) + + pca = {'W': W, 'means': means} + + return pca + +def transform(pca, X, whitenp=0.5, whitenv=None, whitenm=1.0, use_sklearn=True): + if use_sklearn: + # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/decomposition/base.py#L99 + if pca.mean_ is not None: + X = X - pca.mean_ + X_transformed = np.dot(X, pca.components_[:whitenv].T) + if pca.whiten: + X_transformed /= whitenm * np.power(pca.explained_variance_[:whitenv], whitenp) + else: + X = X - pca['means'] + X_transformed = np.dot(X, pca['W']) + return X_transformed + +def whiten_features(X, pca, l2norm=True, whitenp=0.5, whitenv=None, whitenm=1.0, use_sklearn=True): + res = transform(pca, X, whitenp=whitenp, whitenv=whitenv, whitenm=whitenm, use_sklearn=use_sklearn) + if l2norm: + res = res / np.expand_dims(np.linalg.norm(res, axis=1), axis=1) + return res + diff --git a/dirtorch/utils/convenient.py b/dirtorch/utils/convenient.py new file mode 100644 index 0000000..b79c7d2 --- /dev/null +++ b/dirtorch/utils/convenient.py @@ -0,0 +1,188 @@ +import os + +################################################ +# file stuff + +def mkdir(d): + try: os.makedirs(d) + except OSError: pass + + +def mkdir( fname, isfile='auto' ): + ''' Make a directory given a file path + If the path is already a directory, make sure it ends with '/' ! + ''' + if isfile == 'auto': + isfile = bool(os.path.splitext(fname)[1]) + if isfile: + directory = os.path.split(fname)[0] + else: + directory = fname + if directory and not os.path.isdir( directory ): + os.makedirs(directory) +_mkdir = mkdir + + +def touch(filename): + ''' Touch is file. Create the file and directory if necessary. + ''' + assert isinstance(filename, str), 'filename "%s" must be a string' % (str(filename)) + dirs = os.path.split(filename)[0] + mkdir(dirs) + open(filename,'r+' if os.path.isfile(filename) else 'w') # touch + + +def assert_outpath( path, ext='', mkdir=False ): + """ Verify that the output file has correct format. 
+ """ + folder, fname = os.path.split(path) + if ext: assert os.path.splitext(fname)[1] == ext, 'Bad file extension, should be '+ext + if mkdir: _mkdir(folder, isfile=False) + assert os.path.isdir(folder), 'Destination folder not found '+folder + assert not os.path.isfile(path), 'File already exists '+path + + + +################################################ +# Multiprocessing stuff +import multiprocessing as mp +import multiprocessing.dummy + +class _BasePool (object): + def __init__(self, nt=0): + self.n = max(1,min(mp.cpu_count(), nt if nt>0 else nt+mp.cpu_count())) + def starmap(self, func, args): + self.map(lambda a: func(*a), args) + +class ProcessPool (_BasePool): + def __init__(self, nt=0): + _BasePool.__init__(self, nt) + self.map = map if self.n==1 else mp.Pool(self.n).map + +class ThreadPool (_BasePool): + def __init__(self, nt=0): + _BasePool.__init__(self, nt) + self.map = map if self.n==1 else mp.dummy.Pool(self.n).map + + +################################################ +# List utils + +def is_iterable(val, exclude={str}): + if type(exclude) not in (tuple, dict, set): + exclude = {exclude} + try: + if type(val) in exclude: # iterable type, but explicitly excluded + raise TypeError() + plouf = iter(val) + return True + except TypeError: + return False + + +def listify( val, exclude={str} ): + # make it iterable + return val if is_iterable(val,exclude=exclude) else (val,) + + +def unlistify( lis ): + # if it contains just one element, returns it + if len(lis) == 1: + for e in lis: return e + return lis + + +################################################ +# file stuff + +def sig_folder_ext(f): + return (os.path.split(f)[0], os.path.splitext(f)[1]) +def sig_folder(f): + return os.path.split(f)[0] +def sig_ext(f): + return os.path.splitext(f)[1] +def sig_3folder_ext(f): + f = f.replace('//','/') + f = f.replace('//','/') + return tuple(f.split('/')[:3]) + (os.path.splitext(f)[1],) +def sig_all(f): + return () + +def saferm(f, sig=sig_folder_ext ): + if not os.path.isfile(f): + return True + if not hasattr(saferm,'signature'): + saferm.signature = set() # init + + if sig(f) not in saferm.signature: + # reset if the signature is different + saferm.ask = True + saferm.signature.add( sig(f) ) + + if saferm.ask: + print('confirm removal of %s ?
(y/n/all)' %f, end=' ') + ans = input() + if ans not in ('y','all'): return False + if ans == 'all': saferm.ask = False + + os.remove(f) + return True + + +################################################ +# measuring time + +_tics = dict() +from collections import defaultdict +_tics_cum = defaultdict(float) + +def tic(tag='tic'): + from time import time as now + _tics['__last__'] = tag + _tics[tag] = now() + +def toc(tag='', cum=False): + from time import time as now + t = now() + tag = tag or _tics['__last__'] + t -= _tics[tag] + if cum: + nb, oldt = _tics_cum.get(tag,(0,0)) + nb += 1 + t += oldt + _tics_cum[tag] = nb,t + if cum=='avg': t/=nb + print('%selpased time since %s = %gs' % ({False:'',True:'cumulated ','avg':'average '}[cum], tag,t)) + return t + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dirtorch/utils/evaluation.py b/dirtorch/utils/evaluation.py new file mode 100644 index 0000000..c4db204 --- /dev/null +++ b/dirtorch/utils/evaluation.py @@ -0,0 +1,65 @@ +'''Evaluation metrics +''' +import pdb +import numpy as np +import torch + + +def accuracy_topk(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k + + output: torch.FloatTensoror np.array(float) + shape = B * L [* H * W] + L: number of possible labels + + target: torch.IntTensor or np.array(int) + shape = B [* H * W] + ground-truth labels + """ + if isinstance(output, np.ndarray): + pred = (-output).argsort(axis=1) + target = np.expand_dims(target, axis=1) + correct = (pred == target) + + res = [] + for k in topk: + correct_k = correct[:,:k].sum() + res.append(correct_k / target.size) + + if isinstance(output, torch.Tensor): + _, pred = output.topk(max(topk), 1, True, True) + correct = pred.eq(target.unsqueeze(1)) + + res = [] + for k in topk: + correct_k = correct[:,:k].float().view(-1).sum(0) + res.append(correct_k.mul_(1.0 / target.numel())) + + return res + + + +def compute_AP(label, score): + from sklearn.metrics import average_precision_score + return average_precision_score(label, score) + + +def compute_average_precision_quantized(labels, idx, step=0.01): + recall_checkpoints = np.arange(0, 1, step) + def mymax(x, default): + return np.max(x) if len(x) else default + Nrel = np.sum(labels) + if Nrel == 0: + return 0 + recall = np.cumsum(labels[idx])/float(Nrel) + irange = np.arange(1, len(idx)+1) + prec = np.cumsum(labels[idx]).astype(np.float32) / irange + precs = np.array([mymax(prec[np.where(recall > v)], 0) for v in recall_checkpoints]) + return np.mean(precs) + + + +def pixelwise_iou(output, target): + """ For each image, for each label, compute the IoU between + """ + assert False diff --git a/dirtorch/utils/funcs.py b/dirtorch/utils/funcs.py new file mode 100644 index 0000000..49efcb7 --- /dev/null +++ b/dirtorch/utils/funcs.py @@ -0,0 +1,19 @@ +""" generic functions +""" +import pdb +import numpy as np + + +def sigmoid(x, a=1, b=0): + return 1 / (1 + np.exp(a * (b - x))) + + +def sigmoid_range(x, at5, at95): + """ create sigmoid function like that: + sigmoid(at5) = 0.05 + sigmoid(at95) = 0.95 + and returns sigmoid(x) + """ + a = 6 / (at95 - at5) + b = at5 + 3 / a + return sigmoid(x, a, b) diff --git a/dirtorch/utils/pyplot.py b/dirtorch/utils/pyplot.py new file mode 100644 index 0000000..ea4925f --- /dev/null +++ b/dirtorch/utils/pyplot.py @@ -0,0 +1,179 @@ +''' Just a bunch of pyplot utilities... 
+''' +import pdb +import numpy as np + + +def pyplot(globs=None, ion=True, backend='TkAgg'): #None): + ''' load pyplot and shit, in interactive mode + ''' + globs = globs or globals() + if 'pl' not in globs: + if backend: + import matplotlib + matplotlib.use(backend) + import matplotlib.pyplot as pl + if ion: pl.ion() + globs['pl'] = pl + + +def figure(name, clf=True, **kwargs): + pyplot() + f = pl.figure(name, **kwargs) + f.canvas.manager.window.attributes('-topmost',1) + pl.subplots_adjust(0,0,1,1,0,0) + if clf: pl.clf() + return f + + +def pl_imshow( img, **kwargs ): + if isinstance(img, str): + from PIL import Image + img = Image.open(img) + pyplot() + pl.imshow(img, **kwargs) + + +def pl_noticks(): + pl.xticks(()) + pl.yticks(()) + +def fig_imshow( figname, img, **kwargs ): + fig = figure(figname) + pl_imshow( img, ** kwargs) + pdb.set_trace() + + +def crop_text(sentence, maxlen=10): + lines = [''] + for word in sentence.split(): + t = lines[-1] + ' ' + word + if len(t) <= maxlen: + lines[-1] = t + else: + lines.append( word ) + if lines[0] == '': lines.pop(0) + return lines + + +def plot_bbox( bbox, fmt='rect', color='blue', text='', text_effects=False, text_on_box=False, scale=True, fill_color=None, ax=None, **kwargs ): + pyplot() + ax = ax or pl.gca() + + if fmt == 'rect' or fmt == 'xyxy': + ''' bbox = (left, top, right, bottom) + ''' + assert len(bbox) == 4, pdb.set_trace() + x0,y0,x1,y1 = bbox + X,Y = [x0,x0,x1,x1,x0], [y0,y1,y1,y0,y0] + + elif fmt == 'box' or fmt == 'xywh': + ''' bbox = (left, top, width, height) + ''' + assert len(bbox) == 4, pdb.set_trace() + x0,y0,w,h = bbox + X,Y = [x0,x0,x0+w,x0+w,x0], [y0,y0+h,y0+h,y0,y0] + + elif fmt == '4pts': + ''' bbox = ((lx,ly), (tx,ty), (rx,ty), (bx,by)) + ''' + assert len(bbox) >= 4, pdb.set_trace() + bbox = np.asarray(bbox) + X, Y = bbox[:,0], bbox[:,1] + X = list(X)+[X[0]] + Y = list(Y)+[Y[0]] + + elif fmt == '8val': + ''' bbox = 8-tuples of (x0,y0, x1,y0, x0,y1, x1,y1) + ''' + assert len(bbox) >= 8, pdb.set_trace() + X, Y = bbox[0::2], bbox[1::2] + + else: + raise ValueError("bad format for a bbox: %s" % fmt) + + ls = kwargs.pop('ls','-') + line = ax.plot( X, Y, ls, scalex=scale, scaley=scale, color=color, **kwargs) + + if fill_color: + ax.fill(X, Y, fill_color, alpha=0.3) + if text: + text = '\n'.join(crop_text(text, 10)) + + color = line[0].get_color() + + if text_on_box: + text = ax.text(bbox[0], bbox[1]-2, text, fontsize=8, color=color) + else: + text = ax.text( np.mean(X), np.mean(Y), text, + ha='center', va='center', fontsize=16, color=color, + clip_on=True) + + if text_effects: + import matplotlib.patheffects as path_effects + effects = [path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()] + text.set_path_effects(effects) + + return line + +def plot_rect( rect, **kwargs): + return plot_bbox( rect, fmt='xyxy', **kwargs ) + +def plot_poly( poly, **kwargs ): + return plot_bbox( poly, fmt='4pts', **kwargs ) + + +def plot_cam_on_map( pose, fmt="xyz,rpy", fov=70.0, cone=10, marker='+', color='r', ax=None): + if not ax: ax = pl.gca() + + if fmt == "xyz,rpy": + x,y = pose[0,0:2] + A = pose[1,2] + fov*np.pi/180/2 + B = pose[1,2] - fov*np.pi/180/2 + elif fmt == "xylr": + x,y,A,B = pose + else: + raise ValueError("Unknown pose format %s" % str(fmt)) + + ax.plot(x,y, marker, color=color) + ax.plot([x,x+cone*np.cos(A)],[y,y+cone*np.sin(A)], '--', color=color) + ax.plot([x,x+cone*np.cos(B)],[y,y+cone*np.sin(B)], '--', color=color) + + +def subplot_grid(nb, n, aspect=1): + """ automatically split into rows and 
columns + + aspect : float. aspect ratio of the subplots (width / height). + """ + pyplot() + nr = int(np.sqrt(nb * aspect)) + nc = int((nb-1) / nr + 1) + return pl.subplot(nr, nc, n) + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dirtorch/utils/pytorch_loader.py b/dirtorch/utils/pytorch_loader.py new file mode 100644 index 0000000..b6132ff --- /dev/null +++ b/dirtorch/utils/pytorch_loader.py @@ -0,0 +1,302 @@ +import pdb + +from PIL import Image +import numpy as np +import random + +import torch +import torch.utils.data as data + + +def get_loader( dataset, trf_chain, iscuda, + preprocess = {}, # variables for preprocessing (input_size, mean, std, ...) + output = ('img','label'), + batch_size = None, + threads = 1, + shuffle = True, + balanced = 0, use_all = False, + totensor = True, + **_useless_kw): + ''' Get a data loader, given the dataset and some parameters. + + Parameters + ---------- + dataset : Dataset(). + Class containing all images and labels. + + trf_chain : list + list of transforms + + iscuda : bool + + output : tuple of str + tells what to return. 'img', 'label', ... See PytorchLoader(). + + preprocess : dict + {input_size:..., mean=..., std:..., ...} + + batch_size : int + + threads : int + + shuffle : int + + balanced : float in [0,1] + if balanced>0, then will pick dataset samples such that each class is equally represented. + + use_all : bool + if True, will force to use all dataset samples at least once (even if balanced>0) + + Returns + ------- + a pytorch loader. + ''' + from . import transforms + trf_chain = transforms.create(trf_chain, to_tensor=True, **preprocess) + + sampler = None + if balanced: + sampler = BalancedSampler(dataset, use_all=use_all, balanced=balanced) + shuffle = False + + loader = PytorchLoader(dataset, transform=trf_chain, output=output) + + if threads == 1: + return loader + else: + return data.DataLoader( + loader, + batch_size = batch_size, + shuffle = shuffle, + sampler = sampler, + num_workers = threads, + pin_memory = iscuda) + + + + +class PytorchLoader (data.Dataset): + """A pytorch dataset-loader + + Args: + dataset (object): dataset inherited from dataset.Dataset() + + transform (deprecated, callable): pytorch transforms. Use img_and_target_transform instead. + + target_transform (deprecated, callable): applied on target. Use img_and_target_transform instead. + + img_and_target_transform (callable): + applied on dict(img=, label=, bbox=, ...) + and should return a similar dictionary. 
+ + Attributes: + dataset (object): subclass of dataset.Dataset() + """ + + def __init__(self, dataset, transform=None, + target_transform=None, + img_and_target_transform=None, + output=['img','label']): + self.dataset = dataset + self.transform = transform + self.target_transform = target_transform + self.img_and_target_transform = img_and_target_transform + self.output = output + + def __getitem__(self, index): + img_filename = self.dataset.get_filename(index) + + img_and_label = dict( + img_filename = img_filename, + img_key = self.dataset.get_key(index), + img = self.dataset.get_image(index), + label = try_to_get(self.dataset.get_label, index, toint=True) ) + + if self.img_and_target_transform: + # label depends on image (bbox, polygons, etc) + assert self.transform is None + assert self.target_transform is None + + # add optional attributes + if 'bbox' in self.output: + bbox = try_to_get(self.dataset.get_bbox, index) + if bbox: img_and_label['bbox'] = bbox + + if any(a.endswith('_map') for a in self.output): + original_polygons = try_to_get(self.dataset.get_polygons, index, toint=True) + if original_polygons is not None: + img_and_label['polygons'] = original_polygons + + img_and_label = self.img_and_target_transform(img_and_label) + + if original_polygons is not None: + transformed_polygons = img_and_label['polygons'] + + imsize = img_and_label['img'].size + if not isinstance(imsize, tuple): + imsize = imsize()[-2:][::-1] # returns h,w + + if 'label_map' in self.output: + pixlabel = self.dataset.get_label_map(index, imsize, polygons=transformed_polygons) + img_and_label['label_map'] = pixlabel.astype(int) + + # instance level attributes + for out_key in self.output: + for type in ['_instance_map', '_angle_map']: + if not out_key.endswith(type): continue + cls = out_key[:-len(type)] + get_func = getattr(self.dataset,'get'+type) + pixlabel = get_func(index, cls, imsize, polygons=transformed_polygons) + img_and_label[out_key] = pixlabel + else: + # just plain old transform, no influence on labels + + if self.transform is not None: + img_and_label['img'] = self.transform(img_and_label['img']) + + if self.target_transform: + img_and_label['label'] = self.target_transform(img_and_label['label']) + + for o in self.output: + assert img_and_label.get(o) is not None, "Missing field %s for img %s" % (o,img_filename) + return [img_and_label[o] for o in self.output] + + def __len__(self): + return len(self.dataset) + + def __repr__(self): + fmt_str = 'Dataset ' + self.dataset.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: %d\n' % len(self.dataset) + fmt_str += ' Root Location: %s\n' % self.dataset.__dict__.get('root','(unknown)') + if self.img_and_target_transform: + tmp = ' Image_and_target transforms: ' + fmt_str += '{0}{1}\n'.format(tmp, repr(self.img_and_target_transform).replace('\n', '\n' + ' ' * len(tmp))) + if self.transform: + tmp = ' Image transforms: ' + fmt_str += '{0}{1}\n'.format(tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp))) + if self.target_transform: + tmp = ' Target transforms: ' + fmt_str += '{0}{1}\n'.format(tmp, repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + + +class BalancedSampler (data.sampler.Sampler): + """ Data sampler that will provide an equal number of each class + to the network. + + size: float in [0,2] + specify the size increase/decrease w.r.t to the original dataset. 
+ 1 means that the over-classes (with more than median n_per_class images) + will have less items, but conversely, under-classes will have more items. + + balanced: float in [0,1] + specify whether the balance constraint should be respected firmly or not. + if ==1: balance is exactly respected; if ==0, same as dataset (no change). + + use_all: bool + if True, will use all images that a class have, even if it is higher than + what the algorithm wanted to use. + """ + + def __init__(self, dataset, size=1.0, balanced=1.0, use_all=False): + assert 0 <= size <= 2 + assert 0 <= balanced <= 1 + + # enumerate class images + self.cls_imgs = [[] for i in range(dataset.nclass)] + for i in range(len(dataset)): + label = dataset.get_label(i, toint=True) + self.cls_imgs[label].append(i) + + # decide on the number of example per class + self.npc = np.percentile([len(imgs) for imgs in self.cls_imgs], max(0,min(50*size,100))) + + self.balanced = balanced + self.use_all = use_all + + self.nelem = int(0.5 + self.npc * dataset.nclass) # initial estimate + + def __iter__(self): + indices = [] + for i,imgs in enumerate(self.cls_imgs): + np.random.shuffle(imgs) # randomize + + # target size for this class + # target = logarithmic mean + b = self.balanced + if len(imgs): + target = 2**(b*np.log2(self.npc) + (1-b)*np.log2(len(imgs))) + target = int(0.5 + target) + else: + target = 0 + if self.use_all: # use all images + target = max(target, len(imgs)) + + # augment classes until target + res = [] + while len(res) < target: + res += imgs + res = res[:target] # cut + + indices += res + + np.random.shuffle(indices) + self.nelem = len(indices) + return iter(indices) + + def __len__(self): + return self.nelem + + + + +### Helper functions with get_loader() and DatasetLoader() + +def load_one_img( loader ): + ''' Helper to iterate on get_loader() + + loader: output of get_loader() + ''' + iterator = iter(loader) + batch = [] + while iterator: + if not batch: # refill + things = next(iterator) + batch = list(zip(*[t.numpy() if torch.is_tensor(t) else t for t in things])) + yield batch.pop(0) + + +def tensor2img(tensor, model): + """ convert a numpy tensor to a PIL Image + (undo the ToTensor() and Normalize() transforms) + """ + mean = model.preprocess['mean'] + std = model.preprocess['std'] + if not isinstance(tensor, np.ndarray): + if not isinstance(tensor, torch.Tensor): + tensor = tensor.data + tensor = tensor.squeeze().cpu().numpy() + + res = np.uint8(np.clip(255*((tensor.transpose(1,2,0) * std) + mean), 0, 255)) + + from PIL import Image + return Image.fromarray(res) + + +def test_loader_speed(loader_): + ''' Test the speed of a data loader + ''' + from tqdm import tqdm + loader = load_one_img(loader_) + for _ in tqdm(loader): + pass + pdb.set_trace() + + + +def try_to_get(func, *args, **kwargs): + try: + return func(*args, **kwargs) + except NotImplementedError: + return None diff --git a/dirtorch/utils/transforms.py b/dirtorch/utils/transforms.py new file mode 100644 index 0000000..380be12 --- /dev/null +++ b/dirtorch/utils/transforms.py @@ -0,0 +1,822 @@ +import pdb +import numpy as np +from PIL import Image, ImageOps +import torchvision.transforms as tvf +import random +from math import ceil + +from . import transforms_tools as F + + +def create(cmd_line, to_tensor=False, **vars): + ''' Create a sequence of transformations. + + cmd_line: (str) + Comma-separated list of transformations. 
+ Ex: "Rotate(10), Scale(256)" + + to_tensor: (bool) + Whether to add the "ToTensor(), Normalize(mean, std)" + automatically to the end of the transformation string + + vars: (dict) + dictionary of global variables. + ''' + if to_tensor: + if not cmd_line: + cmd_line = "ToTensor(), Normalize(mean=mean, std=std)" + elif to_tensor and 'ToTensor' not in cmd_line: + cmd_line += ", ToTensor(), Normalize(mean=mean, std=std)" + + assert isinstance(cmd_line, str) + + cmd_line = "tvf.Compose([%s])" % cmd_line + try: + return eval(cmd_line, globals(), vars) + except Exception as e: + raise SyntaxError("Cannot interpret this transform list: %s\nReason: %s" % (cmd_line, e)) + + + + +class Identity (object): + """ Identity transform. It does nothing! + """ + def __call__(self, inp): + return inp + + +class PadBad (object): + def __init__(self, size=None, color=(127,127,127)): + print('Warning! The Pad class has a bug') + self.size = size + assert len(color) == 3 + if not all(isinstance(c,int) for c in color): + color = tuple([int(255*c) for c in color]) + self.color = color + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + s = self.size or max(w,h) + + if (s,s) != img.size: + img2 = Image.new('RGB', (s,s), self.color) + img2.paste(img, (0,0)) + img = img2 + + return F.update_img_and_labels(inp, img, aff=(1,0,0,0,1,0)) + + +class Pad(object): + """ Pads the shortest side of the image to a given size + + If size is shorter than the shortest image, then the image will be untouched + """ + + def __init__(self, size, color=(127,127,127)): + self.size = size + assert len(color) == 3 + if not all(isinstance(c,int) for c in color): + color = tuple([int(255*c) for c in color]) + self.color = color + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + if w >= h: + newh = max(h,self.size) + neww = w + else: + newh = h + neww = max(w,self.size) + + if (neww,newh) != img.size: + img2 = Image.new('RGB', (neww,newh), self.color) + img2.paste(img, ((neww-w)//2,(newh-h)//2) ) + img = img2 + + return F.update_img_and_labels(inp, img, aff=(1,0,0,0,1,0)) + +class PadSquare (object): + """ Pads the image to a square size + + The dimension of the output image will be equal to size x size + + If size is None, then the image will be padded to the largest dimension + + If size is smaller than the original image size, the image will be cropped + """ + + def __init__(self, size=None, color=(127,127,127)): + self.size = size + assert len(color) == 3 + if not all(isinstance(c,int) for c in color): + color = tuple([int(255*c) for c in color]) + self.color = color + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + s = self.size or max(w, h) + + + if (s,s) != img.size: + img2 = Image.new('RGB', (s,s), self.color) + img2.paste(img, ((s-w)//2,(s-h)//2) ) + img = img2 + + return F.update_img_and_labels(inp, img, aff=(1,0,0,0,1,0)) + + +class RandomBorder (object): + """ Expands the image with a random size border + """ + + def __init__(self, min_size, max_size, color=(127,127,127)): + assert isinstance(min_size, int) and min_size >= 0 + assert isinstance(max_size, int) and min_size <= max_size + self.min_size = min_size + self.max_size = max_size + assert len(color) == 3 + if not all(isinstance(c,int) for c in color): + color = tuple([int(255*c) for c in color]) + self.color = color + + def __call__(self, inp): + img = F.grab_img(inp) + + bh = random.randint(self.min_size, self.max_size) + bw = random.randint(self.min_size, self.max_size) + + img = ImageOps.expand(img, 
border=(bw,bh,bw,bh), fill=self.color) + + return F.update_img_and_labels(inp, img, aff=(1,0,0,0,1,0)) + + +class Scale (object): + """ Rescale the input PIL.Image to a given size. + Same as torchvision.Scale + + The smallest dimension of the resulting image will be = size. + + if largest == True: same behaviour for the largest dimension. + + if not can_upscale: don't upscale + if not can_downscale: don't downscale + """ + def __init__(self, size, interpolation=Image.BILINEAR, largest=False, can_upscale=True, can_downscale=True): + assert isinstance(size, (float,int)) or (len(size) == 2) + self.size = size + if isinstance(self.size, float): + assert 0 < self.size <= 4, 'bad float self.size, cannot be outside of range ]0,4]' + self.interpolation = interpolation + self.largest = largest + self.can_upscale = can_upscale + self.can_downscale = can_downscale + + def get_params(self, imsize): + w,h = imsize + if isinstance(self.size, int): + is_smaller = lambda a,b: (a>=b) if self.largest else (a<=b) + if (is_smaller(w, h) and w == self.size) or (is_smaller(h, w) and h == self.size): + ow, oh = w, h + elif is_smaller(w, h): + ow = self.size + oh = int(0.5 + self.size * h / w) + else: + oh = self.size + ow = int(0.5 + self.size * w / h) + + elif isinstance(self.size, float): + ow, oh = int(0.5 + self.size*w), int(0.5 + self.size*h) + + else: # tuple of ints + ow, oh = self.size + return ow, oh + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + + size2 = ow,oh = self.get_params(img.size) + + if size2 != img.size: + a1, a2 = img.size, size2 + if (self.can_upscale and min(a1) < min(a2)) or (self.can_downscale and min(a1) > min(a2)): + img = img.resize(size2, self.interpolation) + + return F.update_img_and_labels(inp, img, aff=(ow/w,0,0,0,oh/h,0)) + + + +class RandomScale (Scale): + """Rescale the input PIL.Image to a random size. + + Args: + min_size (int): min size of the smaller edge of the picture. + max_size (int): max size of the smaller edge of the picture. + + ar (float or tuple): + max change of aspect ratio (width/height). + + interpolation (int, optional): Desired interpolation. 
Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, min_size, max_size, ar=1, can_upscale=False, can_downscale=True, interpolation=Image.BILINEAR, largest=False): + Scale.__init__(self, 0, can_upscale=can_upscale, can_downscale=can_downscale, interpolation=interpolation, largest=largest) + assert isinstance(min_size, int) and min_size >= 1 + assert isinstance(max_size, int) and min_size <= max_size + self.min_size = min_size + self.max_size = max_size + if type(ar) in (float,int): ar = (min(1/ar,ar),max(1/ar,ar)) + assert 0.2 < ar[0] <= ar[1] < 5 + self.ar = ar + self.largest = largest + + def get_params(self, imsize): + w,h = imsize + if self.can_upscale: + max_size = self.max_size + else: + max_size = min(self.max_size,min(w,h)) + size = max(min(int(0.5 + F.rand_log_uniform(self.min_size,self.max_size)), self.max_size), self.min_size) + ar = F.rand_log_uniform(*self.ar) # change of aspect ratio + + if not self.largest: + if w < h : # image is taller + ow = size + oh = int(0.5 + size * h / w / ar) + if oh < self.min_size: + ow,oh = int(0.5 + ow*float(self.min_size)/oh),self.min_size + else: # image is wider + oh = size + ow = int(0.5 + size * w / h * ar) + if ow < self.min_size: + ow,oh = self.min_size,int(0.5 + oh*float(self.min_size)/ow) + assert ow >= self.min_size + assert oh >= self.min_size + else: # if self.largest + if w > h: # image is wider + ow = size + oh = int(0.5 + size * h / w / ar) + else: # image is taller + oh = size + ow = int(0.5 + size * w / h * ar) + assert ow <= self.max_size + assert oh <= self.max_size + + return ow, oh + + +class RandomCrop (object): + """Crop the given PIL Image at a random location. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + padding (int or sequence, optional): Optional padding on each border + of the image. Default is 0, i.e no padding. If a sequence of length + 4 is provided, it is used to pad left, top, right, bottom borders + respectively. + """ + + def __init__(self, size, padding=0): + if isinstance(size, int): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + @staticmethod + def get_params(img, output_size): + w, h = img.size + th, tw = output_size + assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % (w,h,tw,th) + + y = np.random.randint(0, h - th) if h > th else 0 + x = np.random.randint(0, w - tw) if w > tw else 0 + return x, y, tw, th + + def __call__(self, inp): + img = F.grab_img(inp) + + padl = padt = 0 + if self.padding > 0: + if F.is_pil_image(img): + img = ImageOps.expand(img, border=self.padding, fill=0) + else: + assert isinstance(img, F.DummyImg) + img = img.expand(border=self.padding) + if isinstance(self.padding, int): + padl = padt = self.padding + else: + padl, padt = self.padding[0:2] + + i, j, tw, th = self.get_params(img, self.size) + img = img.crop((i, j, i+tw, j+th)) + + return F.update_img_and_labels(inp, img, aff=(1,0,padl-i,0,1,padt-j)) + + + +class CenterCrop (RandomCrop): + """Crops the given PIL Image at the center. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. 
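+
+    Example (illustrative):
+        CenterCrop(224) keeps the central 224x224 region; any bbox or polygons
+        in the input dict are shifted by the same crop offset.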
+ """ + @staticmethod + def get_params(img, output_size): + w, h = img.size + th, tw = output_size + y = int(0.5 +((h - th) / 2.)) + x = int(0.5 +((w - tw) / 2.)) + return x, y, tw, th + + + +class CropToBbox(object): + """ Crop the image according to the bounding box. + + margin (float): + ensure a margin around the bbox equal to (margin * min(bbWidth,bbHeight)) + + min_size (int): + result cannot be smaller than this size + """ + def __init__(self, margin=0.5, min_size=0): + self.margin = margin + self.min_size = min_size + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + + assert min(w,h) >= self.min_size + + x0,y0,x1,y1 = inp['bbox'] + assert x0 < x1 and y0 < y1, pdb.set_trace() + bbw, bbh = x1-x0, y1-y0 + margin = int(0.5 + self.margin * min(bbw, bbh)) + + i = max(0, x0 - margin) + j = max(0, y0 - margin) + w = min(w, x1 + margin) - i + h = min(h, y1 + margin) - j + + if w < self.min_size: + i = max(0, i-(self.min_size-w)//2) + w = self.min_size + if h < self.min_size: + j = max(0, j-(self.min_size-h)//2) + h = self.min_size + + img = img.crop((i,j,i+w,j+h)) + + return F.update_img_and_labels(inp, img, aff=(1,0,-i,0,1,-j)) + + + +class RandomRotation(object): + """Rescale the input PIL.Image to a random size. + + Args: + degrees (float): + rotation angle. + + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, degrees, interpolation=Image.BILINEAR): + self.degrees = degrees + self.interpolation = interpolation + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + + angle = np.random.uniform(-self.degrees, self.degrees) + + img = img.rotate(angle, resample=self.interpolation) + w2, h2 = img.size + + aff = F.aff_translate(-w/2,-h/2) + aff = F.aff_mul(aff, F.aff_rotate(-angle * np.pi/180)) + aff = F.aff_mul(aff, F.aff_translate(w2/2,h2/2)) + return F.update_img_and_labels(inp, img, aff=aff) + + +class RandomFlip (object): + """Randomly flip the image. 
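+
+    The flip is horizontal only and happens with probability 0.5; the affine
+    (-1, 0, w-1, 0, 1, 0) mirrors x-coordinates so that bounding boxes and
+    polygons stay consistent with the flipped image.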
+ """ + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + + flip = np.random.rand() < 0.5 + if flip: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + + return F.update_img_and_labels(inp, img, aff=(-1,0,w-1,0,1,0)) + + + +class RandomTilting(object): + """Apply a random tilting (left, right, up, down) to the input PIL.Image + + Args: + maginitude (float): + maximum magnitude of the random skew (value between 0 and 1) + directions (string): + tilting directions allowed (all, left, right, up, down) + examples: "all", "left,right", "up-down-right" + """ + + def __init__(self, magnitude, directions='all'): + self.magnitude = magnitude + self.directions = directions.lower().replace(',',' ').replace('-',' ') + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + + x1,y1,x2,y2 = 0,0,h,w + original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)] + + max_skew_amount = max(w, h) + max_skew_amount = int(ceil(max_skew_amount * self.magnitude)) + skew_amount = random.randint(1, max_skew_amount) + + if self.directions == 'all': + choices = [0,1,2,3] + else: + dirs = ['left', 'right', 'up', 'down'] + choices = [] + for d in self.directions.split(): + try: + choices.append(dirs.index(d)) + except: + raise ValueError('Tilting direction %s not recognized' % d) + + skew_direction = random.choice(choices) + + if skew_direction == 0: + # Left Tilt + new_plane = [(y1, x1 - skew_amount), # Top Left + (y2, x1), # Top Right + (y2, x2), # Bottom Right + (y1, x2 + skew_amount)] # Bottom Left + elif skew_direction == 1: + # Right Tilt + new_plane = [(y1, x1), # Top Left + (y2, x1 - skew_amount), # Top Right + (y2, x2 + skew_amount), # Bottom Right + (y1, x2)] # Bottom Left + elif skew_direction == 2: + # Forward Tilt + new_plane = [(y1 - skew_amount, x1), # Top Left + (y2 + skew_amount, x1), # Top Right + (y2, x2), # Bottom Right + (y1, x2)] # Bottom Left + elif skew_direction == 3: + # Backward Tilt + new_plane = [(y1, x1), # Top Left + (y2, x1), # Top Right + (y2 + skew_amount, x2), # Bottom Right + (y1 - skew_amount, x2)] # Bottom Left + + # To calculate the coefficients required by PIL for the perspective skew, + # see the following Stack Overflow discussion: https://goo.gl/sSgJdj + matrix = [] + + for p1, p2 in zip(new_plane, original_plane): + matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) + matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) + + A = np.matrix(matrix, dtype=np.float) + B = np.array(original_plane).reshape(8) + + homography = np.dot(np.linalg.pinv(A), B) + homography = tuple(np.array(homography).reshape(8)) + + img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC) + + homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3,3)).ravel()[:8] + return F.update_img_and_labels(inp, img, persp=tuple(homography)) + + + +class StillTransform (object): + """ Takes and return an image, without changing its shape or geometry. + """ + def _transform(self, img): + raise NotImplementedError() + + def __call__(self, inp): + img = F.grab_img(inp) + + # transform the image (size should not change) + img = self._transform(img) + + return F.update_img_and_labels(inp, img, aff=(1,0,0,0,1,0)) + + + +class ColorJitter (StillTransform): + """Randomly change the brightness, contrast and saturation of an image. + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 
+ contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. + """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + Arguments are same as that of __init__. + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. + """ + transforms = [] + if brightness > 0: + brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) + transforms.append(tvf.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) + + if contrast > 0: + contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) + transforms.append(tvf.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + + if saturation > 0: + saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) + transforms.append(tvf.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) + + if hue > 0: + hue_factor = np.random.uniform(-hue, hue) + transforms.append(tvf.Lambda(lambda img: F.adjust_hue(img, hue_factor))) + + np.random.shuffle(transforms) + transform = tvf.Compose(transforms) + + return transform + + def _transform(self, img): + transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + return transform(img) + + +class RandomErasing (StillTransform): + """ + Class that performs Random Erasing, an augmentation technique described + in `https://arxiv.org/abs/1708.04896 `_ + by Zhong et al. To quote the authors, random erasing: + + "*... randomly selects a rectangle region in an image, and erases its + pixels with random values.*" + + The size of the random rectangle is controlled using the + :attr:`area` parameter. This area is random in its + width and height. + + Args: + area: The percentage area of the image to occlude. + """ + def __init__(self, area): + self.area = area + + def _transform(self, image): + """ + Adds a random noise rectangle to a random area of the passed image, + returning the original image with this rectangle superimposed. + + :param image: The image to add a random noise rectangle to. 
+ :type image: PIL.Image + :return: The image with the superimposed random rectangle as type + image PIL.Image + """ + w, h = image.size + + w_occlusion_max = int(w * self.area) + h_occlusion_max = int(h * self.area) + + w_occlusion_min = int(w * self.area/2) + h_occlusion_min = int(h * self.area/2) + + if not (w_occlusion_min < w_occlusion_max and h_occlusion_min < h_occlusion_max): + return image + w_occlusion = np.random.randint(w_occlusion_min, w_occlusion_max) + h_occlusion = np.random.randint(h_occlusion_min, h_occlusion_max) + + if len(image.getbands()) == 1: + rectangle = Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion) * 255)) + else: + rectangle = Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion, len(image.getbands())) * 255)) + + assert w > w_occlusion and h > h_occlusion, pdb.set_trace() + random_position_x = np.random.randint(0, w - w_occlusion) + random_position_y = np.random.randint(0, h - h_occlusion) + + image = image.copy() # don't modify the original + image.paste(rectangle, (random_position_x, random_position_y)) + + return image + + +class ToTensor (StillTransform, tvf.ToTensor): + def _transform(self, img): + return tvf.ToTensor.__call__(self, img) + +class Normalize (StillTransform, tvf.Normalize): + def _transform(self, img): + return tvf.Normalize.__call__(self, img) + + +class BBoxToPixelLabel (object): + """ Convert a bbox into per-pixel label + """ + def __init__(self, nclass, downsize, mode): + self.nclass = nclass + self.downsize = downsize + self.mode = mode + self.nbin = 5 + self.log_scale = 1.5 + self.ref_scale = 8.0 + + def __call__(self, inp): + assert isinstance(inp, dict) + + w, h = inp['img'].size + ds = self.downsize + assert w % ds == 0 + assert h % ds == 0 + + x0,y0,x1,y1 = inp['bbox'] + inp['bbox'] = np.int64(inp['bbox']) + + ll = x0/ds + rr = (x1-1)/ds + tt = y0/ds + bb = (y1-1)/ds + l = max(0, int(ll)) + r = min(w//ds, 1+int(rr)) + t = max(0, int(tt)) + b = min(h//ds, 1+int(bb)) + inp['bbox_downscaled'] = np.array((l,t,r,b), dtype=np.int64) + + W, H = w//ds, h//ds + res = np.zeros((H,W), dtype=np.int64) + res[:] = self.nclass # last bin is null class + res[t:b, l:r] = inp['label'] + inp['pix_label'] = res + + if self.mode == 'hough': + # compute hough parameters + topos = lambda left, pos, right: np.floor( self.nbin * (pos - left) / (right - left) ) + def tolog(size): + size = max(size,1e-8) # make it positive + return np.round( np.log(size / self.ref_scale) / np.log(self.log_scale) + (self.nbin-1)/2 ) + + # for each pixel, find its x and y position + yc,xc = np.mgrid[0:H, 0:W] + res = -np.ones((4, H, W), dtype=np.int64) + res[0] = topos(ll, xc, rr) + res[1] = topos(tt, yc, bb) + res[2] = tolog(rr - ll) + res[3] = tolog(bb - tt) + res = np.clip(res, 0, self.nbin-1) + inp['pix_bbox_hough'] = res + + elif self.mode == 'regr': + topos = lambda left, pos, right: (pos - left) / (right - left) + def tolog(size): + size = max(size,1) # make it positive + return np.log(size / self.ref_scale) / np.log(self.log_scale) + + # for each pixel, find its x and y position + yc,xc = np.float64(np.mgrid[0:H, 0:W]) + 0.5 + res = -np.ones((4, H, W), dtype=np.float32) + res[0] = topos(ll, xc, rr) + res[1] = topos(tt, yc, bb) + res[2] = tolog(rr - ll) + res[3] = tolog(bb - tt) + inp['pix_bbox_regr'] = res + + else: + raise NotImplementedError() + + return inp + + + + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser("Script to try out and visualize transformations") + parser.add_argument('--img', 
type=str, default='$HERE/test.png', help='input image') + parser.add_argument('--trfs', type=str, required=True, help='sequence of transformations') + + parser.add_argument('--bbox', action='store_true', help='add a bounding box') + parser.add_argument('--polygons', action='store_true', help='add a polygon') + + parser.add_argument('--input_size', type=int, default=224, help='optional param') + parser.add_argument('--layout', type=int, nargs=2, default=(3,3), help='Number of rows and columns') + + args = parser.parse_args() + + import os + args.img = args.img.replace('$HERE',os.path.dirname(__file__)) + img = Image.open(args.img) + + if args.bbox or args.polygons: + img = dict(img=img) + + if args.bbox: + w, h = img['img'].size + img['bbox'] = (w/4,h/4,3*w/4,3*h/4) + if args.polygons: + w, h = img['img'].size + img['polygons'] = [(1,[(w/4,h/4),(w/2,h/4),(w/4,h/2)])] + + trfs = create(args.trfs, input_size=args.input_size) + + from matplotlib import pyplot as pl + pl.ion() + pl.subplots_adjust(0,0,1,1) + + nr,nc = args.layout + + while True: + for j in range(nr): + for i in range(nc): + pl.subplot(nr,nc,i+j*nc+1) + if i==j==0: + img2 = img + else: + img2 = trfs(img.copy()) + if isinstance(img2, dict): + if 'bbox' in img2: + l,t,r,b = img2['bbox'] + x,y = [l,r,r,l,l], [t,t,b,b,t] + pl.plot(x,y,'--',lw=5) + if 'polygons' in img2: + for label, pts in img2['polygons']: + x,y = zip(*pts) + pl.plot(x,y,'-',lw=5) + img2 = img2['img'] + pl.imshow(img2) + pl.xlabel("%d x %d" % img2.size) + pl.xticks(()) + pl.yticks(()) + pdb.set_trace() + + +''' + +python -m utils.transforms --trfs "Scale(384), ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), RandomErasing(0.5), RandomRotation(10), RandomTilting(0.5, 'all'), RandomScale(240,320), RandomCrop(input_size)" --polygons +''' + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dirtorch/utils/transforms_tools.py b/dirtorch/utils/transforms_tools.py new file mode 100644 index 0000000..ebcf839 --- /dev/null +++ b/dirtorch/utils/transforms_tools.py @@ -0,0 +1,251 @@ +import pdb +import numpy as np +from PIL import Image, ImageOps, ImageEnhance + + +def is_pil_image(img): + return isinstance(img, Image.Image) + +class DummyImg: + ''' This class is a dummy image only defined by its size. + ''' + def __init__(self, size): + self.size = size + + def resize(self, size, *args, **kwargs): + return DummyImg(size) + + def expand(self, border): + w, h = self.size + if isinstance(border, int): + size = (w+2*border, h+2*border) + else: + l,t,r,b = border + size = (w+l+r, h+t+b) + return DummyImg(size) + + def crop(self, border): + w, h = self.size + l,t,r,b = border + assert 0 <= l <= r <= h + assert 0 <= t <= b <= h + size = (r-l, b-t) + return DummyImg(size) + + def rotate(self, angle): + raise NotImplementedError + + def transform(self, size, *args, **kwargs): + return DummyImg(size) + + +def grab_img( img_and_label ): + ''' Called to extract the image from an img_and_label input + (a dictionary). Also compatible with old-style PIL images. + ''' + if isinstance(img_and_label, dict): + # if input is a dictionary, then + # it must contains the img or its size. 
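+        # ('imsize' is assumed to be a (width, height) tuple; a DummyImg of
+        # that size is enough to propagate geometric transforms without pixels)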
+ try: + return img_and_label['img'] + except KeyError: + return DummyImg(img_and_label['imsize']) + + else: + # or it must be the img directly + return img_and_label + + +def update_img_and_labels(img_and_label, img, aff=None, persp=None): + ''' Called to update the img_and_label + ''' + if isinstance(img_and_label, dict): + img_and_label['img'] = img + + if 'bbox' in img_and_label: + l,t,r,b = img_and_label['bbox'] + corners = [(l,t),(l,b),(r,b),(r,t)] + if aff: + pts = [aff_mul(aff, pt) for pt in corners] + elif persp: + pts = [persp_mul(persp, pt) for pt in corners] + else: + raise NotImplementedError() + x,y = map(list,zip(*pts)) + x.sort() + y.sort() + l, r = np.mean(x[:2]), np.mean(x[2:]) + t, b = np.mean(y[:2]), np.mean(y[2:]) + img_and_label['bbox'] = int_tuple(l,t,r,b) + + if 'polygons' in img_and_label: + polygons = [] + for label,pts in img_and_label['polygons']: + if aff: + pts = [int_tuple(*aff_mul(aff, pt)) for pt in pts] + elif persp: + pts = [int_tuple(*persp_mul(persp, pt)) for pt in pts] + else: + raise NotImplementedError() + polygons.append((label, pts)) + img_and_label['polygons'] = polygons + + return img_and_label + + else: + # or it must be the img directly + return img + + +def rand_log_uniform(a, b): + return np.exp(np.random.uniform(np.log(a),np.log(b))) + + +def int_tuple(*args): + return tuple(map(int,args)) + +def aff_translate(tx, ty): + return (1,0,tx, + 0,1,ty) + +def aff_rotate(angle): + return (np.cos(angle),-np.sin(angle), 0, + np.sin(angle), np.cos(angle), 0) + +def aff_mul(aff, aff2): + ''' affine multiplication. + aff: 6-tuple (affine transform) + aff2: 6-tuple (affine transform) or 2-tuple (point) + ''' + assert isinstance(aff, tuple) + assert isinstance(aff2, tuple) + aff = np.array(aff+(0,0,1)).reshape(3,3) + + if len(aff2) == 6: + aff2 = np.array(aff2+(0,0,1)).reshape(3,3) + return tuple(np.dot(aff2, aff)[:2].ravel()) + + elif len(aff2) == 2: + return tuple(np.dot(aff2+(1,), aff.T).ravel()[:2]) + + else: + raise ValueError("bad input %s" % str(aff2)) + +def persp_mul(mat, mat2): + ''' homography (perspective) multiplication. + mat: 8-tuple (homography transform) + mat2: 8-tuple (homography transform) or 2-tuple (point) + ''' + assert isinstance(mat, tuple) + assert isinstance(mat2, tuple) + mat = np.array(mat+(1,)).reshape(3,3) + + if len(mat2) == 8: + mat2 = np.array(mat2+(1,)).reshape(3,3) + return tuple(np.dot(mat2, mat).ravel()[:8]) + + elif len(mat2) == 2: + pt = np.dot(mat2+(1,), mat.T).ravel() + pt /= pt[2] # homogeneous coordinates + return tuple(pt[:2]) + + else: + raise ValueError("bad input %s" % str(aff2)) + + + +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an Image. + Args: + img (PIL Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + Returns: + PIL Image: Brightness adjusted image. + """ + if not is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + Args: + img (PIL Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. 
+ Returns: + PIL Image: Contrast adjusted image. + """ + if not is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + Args: + img (PIL Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + Returns: + PIL Image: Saturation adjusted image. + """ + if not is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + See https://en.wikipedia.org/wiki/Hue for more details on Hue. + Args: + img (PIL Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + Returns: + PIL Image: Hue adjusted image. + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + if not is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + + diff --git a/dirtorch/utils/watcher.py b/dirtorch/utils/watcher.py new file mode 100644 index 0000000..d35d433 --- /dev/null +++ b/dirtorch/utils/watcher.py @@ -0,0 +1,491 @@ +"""Watch is an object to easily watch what is happening +during the training/evaluation of a deep net. +""" +import os +import time +import pdb +from collections import defaultdict +import numpy as np + +from . import evaluation + + +class AverageMeter(object): + """Computes and stores the average and current value of a metric. + It is fast (constant time), regardless of the lenght of the measure series. + + mode: (str). Behavior of the meter. 
+ 'average': just the average of all values since the start + 'sliding': just the average of the last 'nlast' values + 'last': just the last value (=='sliding' with nlast=1) + 'min' : the minimum so far + 'max' : the maximum so far + """ + + def __init__(self, mode='average', nlast=5): + self.mode = mode + self.nlast = nlast + self.reset() + + def reset(self): + self.vals = [] + self.avg = 0 + self.sum = 0 + self.count = 0 + self.is_perf = False + + def export(self): + return {k:val for k,val in self.__dict__.items() if type(val) in (bool, str, float, int, list)} + + def update(self, val, weight=1): + ''' sliding window average ''' + self.vals.append( val ) + self.sum += val * weight + self.count += weight + if self.mode == 'average': + self.avg = self.sum / self.count + elif self.mode == 'sliding': + vals = self.vals[-self.nlast:] + self.avg = sum(vals) / (1e-8+len(vals)) + elif self.mode == 'last': + self.avg = val + elif self.mode == 'min': + self.avg = min(self.avg or float('inf'), val) + elif self.mode == 'max': + self.avg = max(self.avg or -float('inf'), val) + else: + raise ValueError("unknown AverageMeter update policy '%s'" % self.mode) + + def __bool__(self): + return bool(self.count) + __nonzero__ = __bool__ # for python2 + + def __len__(self): + return len(self.vals) + + def tostr(self, name='', budget=100, unit=''): + ''' Print the meter, using more or less characters + ''' + _budget = budget + if name: + name += ': ' + budget -= len(name) + + if isinstance(self.avg, int): + avg = '%d' % self.avg + minavg = len(avg) + val = '' + budget -= len(avg) + len(unit) + else: + avg = '%f' % self.avg + minavg = (avg+'.').find('.') + + val = 'last: %f' % self.vals[-1] + minval = (val+'.').find('.') + + budget -= len(avg) + len(val) + 3 + 2*len(unit) + + while budget < 0 : + old_budget = budget + + if len(val): + val = val[:-1] + budget += 1 + if len(val) < minval: + val = '' # we cannot delete beyond the decimal point + budget += 3 + len(val) + len(unit) # add parenthesis + continue + else: + if len(val) % 2: continue # shrink the other sometimes + + if len(avg) >= minavg and len(name) <= len(avg): + avg = avg[:-1] + budget += 1 + continue # can shrink further + + if len(name) > 2: + name = name[:-2]+' ' # remove last char + budget += 1 + + # cannot shrink anymore + if old_budget == budget: break + + res = name + avg+unit + if val: res += ' (' + val+unit + ')' + res += ' '*max(0, len(res) - _budget) + return res + + +class Watch (object): + """ + Usage: + ------ + - call start() just before the loop + - call tic() at the beginning of the loop (first line) + - call eval_train(measure1=score1, measure2=score2, ...) or eval_test(...) 
+ - call toc() and the end of the loop (last line) + - call stop() after the loop + + Arguments: + ---------- + tfreq: (float or None) + temporal frequency of outputs (in seconds) + + nfreq: (int or None) + iteration frequency of outputs (in iterations) + """ + def __init__(self, tfreq=30.0, nfreq=None): + self.tfreq = tfreq + self.nfreq = nfreq + + # init meters + self.meters = defaultdict(AverageMeter) + self.meters['epoch'] = AverageMeter(mode='last') + self.meters['test_epoch'] = AverageMeter(mode='last') + self.meters['data_time'] = AverageMeter(mode='sliding') + self.meters['batch_time'] = AverageMeter(mode='sliding') + self.meters['lr'] = AverageMeter(mode='sliding') + self.meters['loss'] = AverageMeter(mode='sliding') + + # init current status + self.tostr_t = None + self.cur_n = None + self.batch_size = None + self.last_test = 0 + self.viz = False + + def __getattr__(self, name): + meters = object.__getattribute__(self, 'meters') + if name in meters: + return meters[name] + else: + return object.__getattribute__(self, name) + + def reset(self): + for meter in self.meters.values(): + meter.reset() + + def start(self): + '''Just before the loop over batches + ''' + self.last_time = time.time() + self.cur_n = 0 + self.tostr_t = self.last_time + + def tic(self, batch_size, epoch=0, **kw): + '''Just after loading one batch + ''' + assert self.last_time is not None, "you must call start() before the loop!" + self.meters['data_time'].update(time.time() - self.last_time) + self.batch_size = batch_size + self.meters['epoch'].update(epoch) + n_epochs = len(self.meters['epoch']) + + for name, val in kw.items(): + self.meters[name].mode = 'last' + self.meters[name].update(val) + assert len(self.meters[name]) == n_epochs, "missing values for meter %s (expected %d, got %d)" % (name, n_epochs, len(self.meters[name])) + + def eval_train(self, **measures): + n_epochs = len(self.meters['epoch']) + + for name, score in measures.items(): + self.meters[name].is_perf = True + self.meters[name].update(score, self.batch_size) + assert len(self.meters[name]) == n_epochs, "missing values for meter %s (expected %d, got %d)" % (name, n_epochs, len(self.meters[name])) + + def eval_test(self, mode='average', **measures): + assert self.batch_size is None, "you must call toc() before; measures should concern the entire test" + epoch = self.meters['epoch'].avg + self.meters['test_epoch'].update(epoch) + n_epochs = len(self.meters['test_epoch']) + + for name, val in measures.items(): + name = 'test_'+name + if name not in self.meters: + self.meters[name] = AverageMeter(mode=mode) + self.meters[name].is_perf = True + self.meters[name].update(val) + assert len(self.meters[name]) == n_epochs, "missing values for meter %s (expected %d, got %d)" % (name, n_epochs, len(self.meters[name])) + + if self.viz: self.plot() + + def toc(self): + '''Just after finishing to process one batch + ''' + assert self.batch_size is not None, "you must call tic() at the begining of the loop" + + now = time.time() + self.meters['batch_time'].update(now - self.last_time) + + if (self.tfreq and now-self.tostr_t>self.tfreq) or (self.nfreq and (self.cur_n % self.nfreq) == 0): + self.tostr_t = now + n_meters = sum([bool(meter) for meter in self.meters.values()]) + cols = get_terminal_ncols() + cols_per_meter = (cols - len('Time ')) / n_meters # columns per meter + N = np.int32(np.linspace(0,cols - len('Time '), n_meters+1)) + N = list(N[1:] - N[:-1]) # this sums to the number of available columns + + tt = '' + if self.meters['epoch']: + tt += 
self.meters['epoch'].tostr('Epoch', budget=N.pop()-1)+' ' + tt += 'Time %s %s' % ( + self.meters['data_time'].tostr('data',budget=N.pop()-1,unit='s'), + self.meters['batch_time'].tostr('batch',budget=N.pop(),unit='s')) + for name, meter in sorted(self.meters.items()): + if name in ('epoch', 'data_time', 'batch_time'): continue + if meter: tt += ' '+meter.tostr(name, budget=N.pop()-1) + print(tt) + if self.viz: self.plot() + + self.batch_size = None + self.cur_n += 1 + self.last_time = time.time() + + def stop(self): + '''Just after all the batches have been processed + ''' + res = '' + for name, meter in sorted(self.meters.items()): + if meter.is_perf: + res += '\n * ' + meter.tostr(name) + print(res[1:]) + + def upgrade(self): + '''Upgrade the old watcher to the latest version + ''' + if not hasattr(self,'meters'): + # convert old to new format + self.meters = defaultdict(AverageMeter) + self.meters['epoch'] = AverageMeter(mode='last') + for i,name in enumerate('data_time batch_time lr loss top1 top5'.split()): + try: + self.meters[name] = getattr(self,name) + if i < 4: self.meters[name].mode = 'sliding' + delattr(self, name) + except AttributeError: + continue + if not self.meters['epoch']: + for i in range(self.epoch): + self.meters['epoch'].update(i) + + return self + + def measures(self): + return {name:meter.avg for name,meter in self.meters.items() if meter.is_perf} + + def plot(self): + ''' plot what happened so far. + ''' + import matplotlib.pyplot as pl; pl.ion() + self.upgrade() + + epochs = self.meters['epoch'].vals + test_epochs = self.meters['test_epoch'].vals + + fig = pl.figure('Watch') + pl.subplots_adjust(0.1,0.03,0.97,0.99) + done = {'epoch','test_epoch'} + + ax = pl.subplot(321) + ax.lines = [] + for name in 'data_time batch_time'.split(): + meter = self.meters[name] + if not meter: continue + done.add(name) + n = len(meter.vals) + pl.plot(epochs[:n], meter.vals, label=name) + self.crop_plot(ymin=0) + pl.legend() + + ax = pl.subplot(322) + ax.lines = [] + for name in 'lr'.split(): + meter = self.meters[name] + if not meter: continue + done.add(name) + n = len(meter.vals) + pl.plot(epochs[:n], meter.vals, label=name) + self.crop_plot(ymin=0) + pl.legend() + + def avg(arr): + from scipy.ndimage.filters import uniform_filter + return uniform_filter(arr, size=max(3,len(arr)//20), mode='nearest') + def halfc(color): + pdb.set_trace() + return tuple([c/2 for c in color]) + + ax = pl.subplot(312) + ax.lines = [] + for name in self.meters: + if not name.startswith('loss'): continue + meter = self.meters[name] + if not meter: continue + done.add(name) + n = len(meter.vals) + line = pl.plot(epochs[:n], meter.vals, ':', lw=0.5) + ax.plot(epochs[:n], avg(meter.vals), '-', label=name, color=line[0].get_color()) + self.crop_plot() + pl.legend() + + ax = pl.subplot(313) + ax.lines = [] + for name in self.meters: + if name in done: continue + meter = self.meters[name] + if not meter: continue + done.add(name) + n = len(meter.vals) + if name.startswith('test_'): + epochs_ = test_epochs[:n] + else: + epochs_ = epochs[:n] + line = ax.plot(epochs_, meter.vals, ':', lw=0.5) + ax.plot(epochs_, avg(meter.vals), '-', label=name, color=line[0].get_color()) + self.crop_plot() + pl.legend() + + pl.pause(0.01) # update the figure + + def export(self): + members = {} + for k, v in self.__dict__.items(): + if k == 'meters': + meters = {} + for k1,v1 in v.items(): + meters[k1] = v1.export() + members[k] = meters + else: + members[k] = v + return members + + @staticmethod + def update_all(checkpoint): 
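+        '''Rebuild a Watch (and its AverageMeter objects) from a dict of
+        attributes, e.g. the output of export().'''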
+ watch = Watch() + for k,v in checkpoint.items(): + if 'meters' in k: + meters = defaultdict(AverageMeter) + for k1,v1 in v.items(): + meter = AverageMeter() + meter.__dict__.update(v1) + meters[k1] = meter + watch.__dict__[k] = meters + else: + watch.__dict__[k] = v + return watch + + @staticmethod + def crop_plot(span=0.5, ax=None, xmin=np.inf, xmax=-np.inf, ymin=np.inf, ymax=-np.inf): + import matplotlib.pyplot as pl + if ax is None: ax=pl.gca() + if not ax.lines: return # nothing to do + + # set xlim to the last of all data + for l in ax.lines: + x,y = map(np.asarray, l.get_data()) + xmin = min(xmin,np.min(x[np.isfinite(x)])) + xmax = max(xmax,np.max(x[np.isfinite(x)])) + xmin = xmax - span*(xmax-xmin) + + # set ylim to the span of remaining points + for l in ax.lines: + x,y = map(np.asarray, l.get_data()) + y = y[(x>=xmin) & (x<=xmax) & np.isfinite(y)] # select only relevant points + if y.size == 0: continue + ymin = min(ymin,np.min(y)) + ymax = max(ymax,np.max(y)) + + try: + ax.set_xlim(xmin,xmax+1) + ax.set_ylim(ymin,(ymax+1e-8)*1.01) + except ValueError: + pass #pdb.set_trace() + + +class TensorBoard (object): + """Tensorboard to plot training and validation loss and others + + .. notes:: + + ```shell + conda install -c conda-forge tensorboardx + conda install tensorflow + ``` + + Args: + logdir (str): path to save log + phases (array): phases to plot, e.g., ['train', 'val'] + """ + def __init__(self, logdir, phases): + from tensorboardX import SummaryWriter + if not os.path.exists(logdir): + for key in phases: + os.makedirs(os.path.join(logdir, key)) + + self.phases = phases + self.tb_writer={} + for key in phases: + self.tb_writer[key] = SummaryWriter(os.path.join(logdir, key)) + + def add_scalars(self, phase, watch, names): + """ Add scalar values in watch.meters[names] + """ + if not phase in self.phases: + raise AttributeError('%s is unknown'%phase) + + epochs = sorted(watch.meters['epoch'].vals) + for name in names: + vals = sorted(watch.meters[name].vals) + cnt = watch.meters[name].count + for n, val in zip(epochs, vals): + self.tb_writer[phase].add_scalar(name, val, n*cnt) + + def close(): + for key in self.phases: + self.tb_writer[key].close() + + + +def get_terminal_ncols(default=160): + try: + import sys + from termios import TIOCGWINSZ + from fcntl import ioctl + from array import array + except ImportError: + return default + else: + try: + return array('h', ioctl(sys.stdout, TIOCGWINSZ, '\0' * 8))[1] + except: + try: + from os.environ import get + except ImportError: + return default + else: + return int(get('COLUMNS', 1)) - 1 + + + +if __name__ == '__main__': + import time + + # test printing size + batch_size = 256 + + watch = Watch(tfreq=0.5) + watch.start(epoch=0) + watch.meters['top1'].is_perf = True + watch.meters['top5'].is_perf = True + + for epoch in range(99999): + watch.tic(batch_size) + time.sleep(0.1) + watch.meters['top1'].update(1-np.exp(-epoch/10)) + watch.meters['top5'].update(1-np.exp(-epoch/5)) + watch.toc(loss=np.sin(epoch/10), lr=np.cos(epoch/20)) + + watch.stop() + + pdb.set_trace()