Add helper function for datasets
tristandeleu committed Jan 29, 2020
1 parent fde633c commit 3c0ab81
Showing 3 changed files with 160 additions and 121 deletions.
108 changes: 108 additions & 0 deletions maml/datasets.py
@@ -0,0 +1,108 @@
import torch.nn.functional as F

from collections import namedtuple
from torchmeta.datasets import Omniglot, MiniImagenet
from torchmeta.toy import Sinusoid
from torchmeta.transforms import ClassSplitter, Categorical, Rotation
from torchvision.transforms import ToTensor, Resize, Compose

from maml.model import ModelConvOmniglot, ModelConvMiniImagenet, ModelMLPSinusoid
from maml.utils import ToTensor1D

Benchmark = namedtuple('Benchmark', 'meta_train_dataset meta_val_dataset '
'meta_test_dataset model loss_function')

def get_benchmark_by_name(name,
folder,
num_ways,
num_shots,
num_shots_test,
hidden_size=None):
dataset_transform = ClassSplitter(shuffle=True,
num_train_per_class=num_shots,
num_test_per_class=num_shots_test)
if name == 'sinusoid':
transform = ToTensor1D()

meta_train_dataset = Sinusoid(num_shots + num_shots_test,
num_tasks=1000000,
transform=transform,
target_transform=transform,
dataset_transform=dataset_transform)
meta_val_dataset = Sinusoid(num_shots + num_shots_test,
num_tasks=1000000,
transform=transform,
target_transform=transform,
dataset_transform=dataset_transform)
meta_test_dataset = Sinusoid(num_shots + num_shots_test,
num_tasks=1000000,
transform=transform,
target_transform=transform,
dataset_transform=dataset_transform)

model = ModelMLPSinusoid(hidden_sizes=[40, 40])
loss_function = F.mse_loss

elif name == 'omniglot':
class_augmentations = [Rotation([90, 180, 270])]
transform = Compose([Resize(28), ToTensor()])

meta_train_dataset = Omniglot(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_train=True,
class_augmentations=class_augmentations,
dataset_transform=dataset_transform,
download=True)
meta_val_dataset = Omniglot(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_val=True,
class_augmentations=class_augmentations,
dataset_transform=dataset_transform)
meta_test_dataset = Omniglot(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_test=True,
dataset_transform=dataset_transform)

model = ModelConvOmniglot(num_ways, hidden_size=hidden_size)
loss_function = F.cross_entropy

elif name == 'miniimagenet':
transform = Compose([Resize(84), ToTensor()])

meta_train_dataset = MiniImagenet(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_train=True,
dataset_transform=dataset_transform,
download=True)
meta_val_dataset = MiniImagenet(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_val=True,
dataset_transform=dataset_transform)
meta_test_dataset = MiniImagenet(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_test=True,
dataset_transform=dataset_transform)

model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size)
loss_function = F.cross_entropy

else:
raise NotImplementedError('Unknown dataset `{0}`.'.format(name))

return Benchmark(meta_train_dataset=meta_train_dataset,
meta_val_dataset=meta_val_dataset,
meta_test_dataset=meta_test_dataset,
model=model,
loss_function=loss_function)
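
A minimal usage sketch of the new helper, for context (not part of the commit; the 'omniglot' / 5-way / 1-shot values and the 'data' folder below are illustrative, while every function and field name comes from the file above):

from torchmeta.utils.data import BatchMetaDataLoader

from maml.datasets import get_benchmark_by_name

# One call builds the three dataset splits, the model, and the loss function.
# The 'sinusoid' branch ignores folder, num_ways, and hidden_size.
benchmark = get_benchmark_by_name('omniglot',   # or 'sinusoid' / 'miniimagenet'
                                  'data',       # illustrative download folder
                                  num_ways=5,
                                  num_shots=1,
                                  num_shots_test=15,
                                  hidden_size=64)

# The namedtuple fields plug directly into torchmeta's loader:
meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                            batch_size=16,
                                            shuffle=True)
loss_function = benchmark.loss_function   # F.cross_entropy for omniglot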
65 changes: 19 additions & 46 deletions test.py
@@ -4,14 +4,9 @@
 import json
 
 from torchmeta.utils.data import BatchMetaDataLoader
-from torchmeta.datasets import Omniglot, MiniImagenet
-from torchmeta.toy import Sinusoid
-from torchmeta.transforms import ClassSplitter, Categorical
-from torchvision.transforms import ToTensor, Resize, Compose
 
-from maml.model import ModelConvOmniglot, ModelConvMiniImagenet, ModelMLPSinusoid
+from maml.datasets import get_benchmark_by_name
 from maml.metalearners import ModelAgnosticMetaLearning
-from maml.utils import ToTensor1D
 
 def main(args):
     with open(args.config, 'r') as f:
@@ -26,49 +21,27 @@ def main(args):
     device = torch.device('cuda' if args.use_cuda
         and torch.cuda.is_available() else 'cpu')
 
-    dataset_transform = ClassSplitter(shuffle=True,
-                                      num_train_per_class=config['num_shots'],
-                                      num_test_per_class=config['num_shots_test'])
-    if config['dataset'] == 'sinusoid':
-        transform = ToTensor1D()
-        meta_test_dataset = Sinusoid(config['num_shots'] + config['num_shots_test'],
-            num_tasks=1000000, transform=transform, target_transform=transform,
-            dataset_transform=dataset_transform)
-        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
-        loss_function = F.mse_loss
-
-    elif config['dataset'] == 'omniglot':
-        transform = Compose([Resize(28), ToTensor()])
-        meta_test_dataset = Omniglot(config['folder'], transform=transform,
-            target_transform=Categorical(config['num_ways']),
-            num_classes_per_task=config['num_ways'], meta_test=True,
-            dataset_transform=dataset_transform, download=True)
-        model = ModelConvOmniglot(config['num_ways'],
-            hidden_size=config['hidden_size'])
-        loss_function = F.cross_entropy
-
-    elif config['dataset'] == 'miniimagenet':
-        transform = Compose([Resize(84), ToTensor()])
-        meta_test_dataset = MiniImagenet(config['folder'], transform=transform,
-            target_transform=Categorical(config['num_ways']),
-            num_classes_per_task=config['num_ways'], meta_test=True,
-            dataset_transform=dataset_transform, download=True)
-        model = ModelConvMiniImagenet(config['num_ways'],
-            hidden_size=config['hidden_size'])
-        loss_function = F.cross_entropy
-
-    else:
-        raise NotImplementedError('Unknown dataset `{0}`.'.format(config['dataset']))
+    benchmark = get_benchmark_by_name(config['dataset'],
+                                      config['folder'],
+                                      config['num_ways'],
+                                      config['num_shots'],
+                                      config['num_shots_test'],
+                                      hidden_size=config['hidden_size'])
 
     with open(config['model_path'], 'rb') as f:
-        model.load_state_dict(torch.load(f, map_location=device))
+        benchmark.model.load_state_dict(torch.load(f, map_location=device))
 
-    meta_test_dataloader = BatchMetaDataLoader(meta_test_dataset,
-        batch_size=config['batch_size'], shuffle=True,
-        num_workers=args.num_workers, pin_memory=True)
-    metalearner = ModelAgnosticMetaLearning(model,
-        first_order=config['first_order'], num_adaptation_steps=config['num_steps'],
-        step_size=config['step_size'], loss_function=loss_function, device=device)
+    meta_test_dataloader = BatchMetaDataLoader(benchmark.meta_test_dataset,
+                                               batch_size=config['batch_size'],
+                                               shuffle=True,
+                                               num_workers=args.num_workers,
+                                               pin_memory=True)
+    metalearner = ModelAgnosticMetaLearning(benchmark.model,
+                                            first_order=config['first_order'],
+                                            num_adaptation_steps=config['num_steps'],
+                                            step_size=config['step_size'],
+                                            loss_function=benchmark.loss_function,
+                                            device=device)
 
     results = metalearner.evaluate(meta_test_dataloader,
                                    max_batches=config['num_batches'],
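
After this refactor, main() only reads the configuration keys visible above. A sketch of a matching config.json, written from Python (all values illustrative; only the key names are taken from the code):

import json

config = {
    'dataset': 'omniglot',             # 'sinusoid' / 'omniglot' / 'miniimagenet'
    'folder': 'data',                  # illustrative dataset folder
    'num_ways': 5,
    'num_shots': 1,
    'num_shots_test': 15,
    'hidden_size': 64,
    'model_path': 'results/model.th',  # illustrative path to trained weights
    'batch_size': 16,
    'first_order': False,
    'num_steps': 1,                    # adaptation steps per task
    'step_size': 0.4,                  # inner-loop learning rate
    'num_batches': 100,                # evaluation batches
}

with open('config.json', 'w') as f:
    json.dump(config, f, indent=2)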
108 changes: 33 additions & 75 deletions train.py
@@ -1,20 +1,14 @@
 import torch
 import torch.nn.functional as F
 import math
 import os
 import time
 import json
 import logging
 
 from torchmeta.utils.data import BatchMetaDataLoader
-from torchmeta.datasets import Omniglot, MiniImagenet
-from torchmeta.toy import Sinusoid
-from torchmeta.transforms import ClassSplitter, Categorical, Rotation
-from torchvision.transforms import ToTensor, Resize, Compose
 
-from maml.model import ModelConvOmniglot, ModelConvMiniImagenet, ModelMLPSinusoid
+from maml.datasets import get_benchmark_by_name
 from maml.metalearners import ModelAgnosticMetaLearning
-from maml.utils import ToTensor1D
 
 def main(args):
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
@@ -39,79 +33,43 @@ def main(args):
     logging.info('Saving configuration file in `{0}`'.format(
         os.path.abspath(os.path.join(folder, 'config.json'))))
 
-    dataset_transform = ClassSplitter(shuffle=True,
-                                      num_train_per_class=args.num_shots,
-                                      num_test_per_class=args.num_shots_test)
-    class_augmentations = [Rotation([90, 180, 270])]
-    if args.dataset == 'sinusoid':
-        transform = ToTensor1D()
-
-        meta_train_dataset = Sinusoid(args.num_shots + args.num_shots_test,
-            num_tasks=1000000, transform=transform, target_transform=transform,
-            dataset_transform=dataset_transform)
-        meta_val_dataset = Sinusoid(args.num_shots + args.num_shots_test,
-            num_tasks=1000000, transform=transform, target_transform=transform,
-            dataset_transform=dataset_transform)
-
-        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
-        loss_function = F.mse_loss
-
-    elif args.dataset == 'omniglot':
-        transform = Compose([Resize(28), ToTensor()])
-
-        meta_train_dataset = Omniglot(args.folder, transform=transform,
-            target_transform=Categorical(args.num_ways),
-            num_classes_per_task=args.num_ways, meta_train=True,
-            class_augmentations=class_augmentations,
-            dataset_transform=dataset_transform, download=True)
-        meta_val_dataset = Omniglot(args.folder, transform=transform,
-            target_transform=Categorical(args.num_ways),
-            num_classes_per_task=args.num_ways, meta_val=True,
-            class_augmentations=class_augmentations,
-            dataset_transform=dataset_transform)
-
-        model = ModelConvOmniglot(args.num_ways, hidden_size=args.hidden_size)
-        loss_function = F.cross_entropy
-
-    elif args.dataset == 'miniimagenet':
-        transform = Compose([Resize(84), ToTensor()])
-
-        meta_train_dataset = MiniImagenet(args.folder, transform=transform,
-            target_transform=Categorical(args.num_ways),
-            num_classes_per_task=args.num_ways, meta_train=True,
-            class_augmentations=class_augmentations,
-            dataset_transform=dataset_transform, download=True)
-        meta_val_dataset = MiniImagenet(args.folder, transform=transform,
-            target_transform=Categorical(args.num_ways),
-            num_classes_per_task=args.num_ways, meta_val=True,
-            class_augmentations=class_augmentations,
-            dataset_transform=dataset_transform)
-
-        model = ModelConvMiniImagenet(args.num_ways, hidden_size=args.hidden_size)
-        loss_function = F.cross_entropy
-
-    else:
-        raise NotImplementedError('Unknown dataset `{0}`.'.format(args.dataset))
-
-    meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
-        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
-        pin_memory=True)
-    meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
-        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
-        pin_memory=True)
-
-    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)
-    metalearner = ModelAgnosticMetaLearning(model, meta_optimizer,
-        first_order=args.first_order, num_adaptation_steps=args.num_steps,
-        step_size=args.step_size, loss_function=loss_function, device=device)
+    benchmark = get_benchmark_by_name(args.dataset,
+                                      args.folder,
+                                      args.num_ways,
+                                      args.num_shots,
+                                      args.num_shots_test,
+                                      hidden_size=args.hidden_size)
+
+    meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
+                                                batch_size=args.batch_size,
+                                                shuffle=True,
+                                                num_workers=args.num_workers,
+                                                pin_memory=True)
+    meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
+                                              batch_size=args.batch_size,
+                                              shuffle=True,
+                                              num_workers=args.num_workers,
+                                              pin_memory=True)
+
+    meta_optimizer = torch.optim.Adam(benchmark.model.parameters(), lr=args.meta_lr)
+    metalearner = ModelAgnosticMetaLearning(benchmark.model,
+                                            meta_optimizer,
+                                            first_order=args.first_order,
+                                            num_adaptation_steps=args.num_steps,
+                                            step_size=args.step_size,
+                                            loss_function=benchmark.loss_function,
+                                            device=device)
 
     best_value = None
 
     # Training loop
     epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(args.num_epochs)))
     for epoch in range(args.num_epochs):
-        metalearner.train(meta_train_dataloader, max_batches=args.num_batches,
-            verbose=args.verbose, desc='Training', leave=False)
+        metalearner.train(meta_train_dataloader,
+                          max_batches=args.num_batches,
+                          verbose=args.verbose,
+                          desc='Training',
+                          leave=False)
         results = metalearner.evaluate(meta_val_dataloader,
                                        max_batches=args.num_batches,
                                        verbose=args.verbose,
@@ -130,7 +88,7 @@ def main(args):

     if save_model and (args.output_folder is not None):
         with open(args.model_path, 'wb') as f:
-            torch.save(model.state_dict(), f)
+            torch.save(benchmark.model.state_dict(), f)
 
-    if hasattr(meta_train_dataset, 'close'):
-        meta_train_dataset.close()
+    if hasattr(benchmark.meta_train_dataset, 'close'):
+        benchmark.meta_train_dataset.close()
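
The two-stage format string that builds epoch_desc in the training loop above is easy to misread, so here is a quick worked example (assuming num_epochs=50):

import math

num_epochs = 50
width = 1 + int(math.log10(num_epochs))    # 1 + 1 = 2 digits
epoch_desc = 'Epoch {{0: <{0}d}}'.format(width)
# The doubled braces survive the first .format(), leaving 'Epoch {0: <2d}'.
print(epoch_desc.format(3))                # 'Epoch 3 ' (left-aligned, width 2)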
