added visualizations
qbilius committed Dec 1, 2022
1 parent 59b1bdc commit 195e4bc
Showing 6 changed files with 109 additions and 42 deletions.
52 changes: 40 additions & 12 deletions README.md
@@ -1,25 +1,53 @@
# Amodal completion with transformers


## Description

An experiment to see if we can train [ViT](https://arxiv.org/abs/2010.11929) to output amodally completed shapes. This ViT implementation is based on [PyTorch Lightning Tutorial 11](https://pytorch-lightning.readthedocs.io/en/stable/notebooks/course_UvA-DL/11-vision-transformer.html)
An experiment to see if we can train [ViT](https://arxiv.org/abs/2010.11929) to output amodally completed shapes. Amodal completion is a perceptual phenomenon where shapes occluded by other shapes appear to us as complete. For instance, if a disc occludes part of a rectangle, we still perceive that rectangle as a rectangle rather than as some odd shape with a small portion missing.

Here we hypothesize that training a ViT to output the full shapes hidden behind an occluder (these full shapes are the targets the model learns to predict) provides a sufficient signal for learning amodal completion. Caveat: only rectangles and discs are used, so our results may not generalize to more complex scenes.

This ViT implementation is based on [PyTorch Lightning Tutorial 11](https://pytorch-lightning.readthedocs.io/en/stable/notebooks/course_UvA-DL/11-vision-transformer.html).
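
To make the training signal concrete, here is a minimal sketch of how such occlusion examples could be generated. The repo's actual generator is `dataset.OverlappingShapes`; the function below is an illustrative stand-in, not its real code:

```python
import numpy as np

def make_example(image_size=16, rng=np.random):
    # Target: the full, amodally complete rectangle
    target = np.zeros((image_size, image_size), dtype=np.float32)
    x0, y0 = rng.randint(2, 8, size=2)
    x1, y1 = x0 + rng.randint(4, 7), y0 + rng.randint(4, 7)
    target[y0:y1, x0:x1] = 1.0

    # Occluder: a disc drawn on top of the rectangle
    cy, cx, r = rng.randint(4, 12), rng.randint(4, 12), rng.randint(2, 5)
    yy, xx = np.ogrid[:image_size, :image_size]
    disc = (yy - cy) ** 2 + (xx - cx) ** 2 <= r ** 2

    # Input: the scene the model sees, with part of the rectangle hidden
    img = target.copy()
    img[disc] = 0.5
    return img, target
```

The model is given `img` and trained to reproduce `target`, i.e. to fill in the part of the rectangle hidden behind the disc.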


## How to run

First, install dependencies
1. Install dependencies

```bash
# clone project
git clone https://github.com/qbilius/amodal

# install project
cd amodal
pip install -e .
pip install -r requirements.txt
```

2. Run training locally: `python amodal/train.py`.
3. Observe results with tensorboard: `tensorboard --logdir=output`.
4. Visualize loss with `python amodal/visualization.py plot_loss --version <version number>`.
5. Visualize amodal completion results with `python amodal/visualization.py plot_results --version <version number>`.

## Details

- Architecture (see the sketch after this list):
  - Image embedding into a 64-dimensional space
  - Positional encoding, sampled from a normal distribution
  - 4 transformer layers with 128-dimensional hidden layers
  - A final fully-connected prediction layer that de-embeds outputs back into image space
- Optimizer: SGD with learning rate 0.1 and momentum 0.9
- Training: 150 epochs on a dataset of 50k examples (~2 hours)
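
In code, these bullet points correspond roughly to the sketch below. This is an illustrative paraphrase, not the repo's actual model class; in particular, the number of attention heads and the use of `nn.TransformerEncoder` are assumptions:

```python
import torch
import torch.nn as nn

class ViTSketch(nn.Module):
    def __init__(self, patch_dim=4 * 4, embed_dim=64, hidden_dim=128,
                 num_layers=4, num_patches=16):
        super().__init__()
        # Image patches are embedded into a 64-dimensional space
        self.embed = nn.Linear(patch_dim, embed_dim)
        # Positional encoding, sampled from a normal distribution
        self.pos = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        # 4 transformer layers with 128-dimensional hidden (feedforward) layers
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4,
                                           dim_feedforward=hidden_dim,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        # A final fully-connected layer de-embeds tokens back into image space
        self.head = nn.Linear(embed_dim, patch_dim)

    def forward(self, x):  # x: (batch, num_patches, patch_dim)
        h = self.encoder(self.embed(x) + self.pos)
        return self.head(h)  # per-pixel logits; sigmoid is applied downstream
```

Given the `torch.sigmoid` applied to model outputs in `amodal/visualization.py`, the loss is presumably a per-pixel binary cross-entropy, optimized with the SGD settings above.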


## Results

[Checkpoint](https://github.com/qbilius/amodal/releases/download/v1.0.0/last.ckpt) - [Parameters](https://github.com/qbilius/amodal/releases/download/v1.0.0/hparams.yaml) - [Log](https://github.com/qbilius/amodal/releases/download/v1.0.0/events.out.tfevents)
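
To poke at the released checkpoint without retraining, it can be loaded directly. A minimal sketch, assuming `last.ckpt` has been downloaded into the working directory:

```python
import torch
from amodal import train

model = train.Model.load_from_checkpoint('last.ckpt')
model.eval()
with torch.no_grad():
    x = torch.rand(1, 1, 16, 16)               # stand-in for a real occlusion image
    patches = model.model.img_to_patch(x)
    completed = torch.sigmoid(model(patches))  # predicted amodal completion
```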

![](results/loss.png)

![](results/results.png)

## License

21 changes: 12 additions & 9 deletions amodal/train.py
@@ -10,7 +10,7 @@
from amodal import dataset, model


DATA_PATH = Path('/Users/qbilius/datasets/amodal/')
OUTPUT_PATH = Path('/Users/qbilius/datasets/amodal/')


class DataModule(pl.LightningDataModule):
@@ -24,7 +24,7 @@ def __init__(self, batch_size=100, image_size=16, patch_size=4):
    def _get_dataloader(self, size: int, stage: str):
        seed = hash(stage) % 2**32
        data = dataset.SVGDataset(
            data_file=DATA_PATH / f'{stage}.npy',
            data_file=OUTPUT_PATH / f'{stage}.npy',
            size=size,
            image_size=self.image_size,
            patch_size=self.patch_size,
@@ -119,27 +119,30 @@ def main():
    pl.seed_everything(1, workers=True)

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_path', default='output', action='store_true')
    parser.add_argument('--output_path', default=OUTPUT_PATH)
    parser.add_argument('--test', default=False, action='store_true')
    parser = DataModule.add_argparse_args(parser)
    parser = Model.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    output_path = Path(args.output_path)

    if args.test:
        version = 'test'
        path = Path(args.output_path) / version
        if path.exists():
            shutil.rmtree(path)
        log_dir = output_path / version
        if log_dir.exists():
            shutil.rmtree(log_dir)
    else:
        version = None
        version = max([int(p.stem.split('_')[1]) for p in output_path.glob('version_*')]) + 1
        log_dir = output_path / f'version_{version}'

    data = DataModule.from_argparse_args(args)
    model = Model(**vars(args))

    callbacks = [
        pl.callbacks.ModelCheckpoint(
            dirpath=DATA_PATH,
            dirpath=log_dir,
            save_top_k=0,
            save_last=True
        )
@@ -152,7 +155,7 @@ def main():
        # enable_checkpointing=False,
        max_epochs=150,
        val_check_interval=500,
        logger=pl.loggers.TensorBoardLogger(save_dir=args.output_path, name='', version=version),
        logger=pl.loggers.TensorBoardLogger(save_dir=output_path, name='', version=version),
        callbacks=callbacks
    )
    trainer.fit(model, data)
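
With these changes, each training run logs into its own numbered directory, and the checkpoint lands next to the TensorBoard events. A hypothetical layout after one run:

```
output/
└── version_1/
    ├── events.out.tfevents...   (TensorBoard log, read by plot_loss)
    ├── hparams.yaml
    └── last.ckpt                (written by ModelCheckpoint(save_last=True))
```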
77 changes: 56 additions & 21 deletions amodal/visualization.py
@@ -1,30 +1,65 @@
from pathlib import Path
import matplotlib.pyplot as plt

import fire
import torch
import tbparse
import matplotlib.pyplot as plt

from amodal import train, dataset


model = train.Model.load_from_checkpoint(train.DATA_PATH / 'last.ckpt')
data = dataset.SVGDataset(train.DATA_PATH / 'val.npy')
generator = dataset.OverlappingShapes(image_size=16)
def plot_loss(output_path=train.OUTPUT_PATH, version=1, savefig=False):
    log_dir = Path(output_path) / f'version_{version}'
    tfevents = list(log_dir.glob('*tfevents*'))[0]
    df = tbparse.SummaryReader(str(tfevents)).scalars
    pv = (df
          .loc[df.tag.isin(['train_loss', 'val_loss'])]
          .pivot(index='step', columns='tag', values='value')
          )
    pv.train_loss.plot()
    pv.val_loss.dropna().plot()
    plt.legend()
    plt.show()

    if savefig:
        plt.savefig('results/loss.png', bbox_inches='tight', transparent=True)


def plot_results(output_path=train.OUTPUT_PATH, version=1, savefig=False, seed=None):
    model = train.Model.load_from_checkpoint(str(output_path / f'version_{version}' / 'last.ckpt'))
    data = dataset.SVGDataset(train.OUTPUT_PATH / 'val.npy', seed=seed)
    generator = dataset.OverlappingShapes(image_size=16)

    fig, axes = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True, figsize=(6, 2))
    plt.suptitle('Input image (above) and predicted amodal completion (below)')
    for ax1, ax2 in axes.T:
        img, _ = generator()
        x = data.transform(img).unsqueeze(0)
        x = model.model.img_to_patch(x)

        out = model(x)
        y_hat = (model.model
                 .patch_to_img(torch.sigmoid(out))
                 .repeat([1, 3, 1, 1])
                 .permute(0, 2, 3, 1)
                 .detach()
                 .numpy()
                 [0]
                 )

        ax1.imshow(img)
        ax1.get_xaxis().set_ticks([])
        ax1.get_yaxis().set_ticks([])

        ax2.imshow(y_hat)
        ax2.get_xaxis().set_ticks([])
        ax2.get_yaxis().set_ticks([])

    plt.show()

    if savefig:
        plt.savefig('results/results.png', bbox_inches='tight', transparent=True)

if __name__ == '__main__':
fire.Fire()
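
For reference, `img_to_patch` and `patch_to_img` used above are the usual ViT patch flattening helpers. Here is a minimal sketch of what such helpers typically do for the 16×16 images and 4×4 patches used here; the repo's own implementations may differ:

```python
import torch

def img_to_patch(x, patch_size=4):
    # (B, C, H, W) -> (B, num_patches, C * patch_size**2)
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)    # (B, H', W', C, p, p)
    return x.flatten(3).flatten(1, 2)  # (B, H' * W', C * p * p)

def patch_to_img(x, patch_size=4, image_size=16, channels=1):
    # Inverse of img_to_patch: (B, num_patches, C * p * p) -> (B, C, H, W)
    B, n, d = x.shape
    hp = image_size // patch_size
    x = x.reshape(B, hp, hp, channels, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5)    # (B, C, H', p, W', p)
    return x.reshape(B, channels, image_size, image_size)
```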
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
fire
torch
torchvision
pytorch-lightning
Binary file added results/loss.png
Binary file added results/results.png
