Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
husencd committed Jun 22, 2018
1 parent aac27b1 commit 91dbcff
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 39 deletions.
25 changes: 12 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import os
import json

from args import parse_args
from model import get_model_param
from data import Driver
Expand Down Expand Up @@ -36,6 +37,7 @@ def main():
torch.manual_seed(args.manual_seed)

args.use_cuda = args.use_cuda and torch.cuda.is_available()

device = torch.device("cuda" if args.use_cuda else "cpu")

# create model
Expand All @@ -48,11 +50,7 @@ def main():

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.001, patience=args.lr_patience)

lr_mult = []
Expand Down Expand Up @@ -83,6 +81,7 @@ def main():
train_batch_logger = Logger(
os.path.join(args.result_path, 'train_batch.log'),
['epoch', 'batch', 'iter', 'loss', 'top1', 'top3', 'lr'])

if args.val:
val_dataset = Driver(root=args.data_path, train=False, test=True)
val_loader = DataLoader(
Expand All @@ -100,25 +99,25 @@ def main():
if args.train:
adjust_learning_rate(optimizer, epoch, lr_mult, args)
train_epoch(epoch, train_loader, model, criterion, optimizer, args, device, train_logger, train_batch_logger, vis)
print()
print('\n')

if args.val:
val_loss, val_prec1 = val_epoch(epoch, val_loader, model, criterion, args, device, val_logger, vis)
print()
print('\n')
# remember best prec@1 and save checkpoint
if val_prec1 > best_prec1:
best_prec1 = val_prec1
best_epoch = epoch
print('=> Saving current best model...\n')
save_file_path = os.path.join(args.result_path, 'save_best_{}_{}.pth'.format(args.arch, epoch))
state = {
'epoch': best_epoch,
checkpoint = {
'arch': args.arch,
'epoch': best_epoch,
'best_prec1': best_prec1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'model': model.state_dict(),
'optimizer': optimizer.state_dict()
}
torch.save(state, save_file_path)
torch.save(checkpoint, save_file_path)

# if args.train and args.val:
# scheduler.step(val_loss)
Expand All @@ -135,7 +134,7 @@ def main():
saved_model_path = os.path.join(args.result_path, 'save_best_{}_{}.pth'.format(args.arch, best_epoch))
print("Using '{}' for test...".format(saved_model_path))
checkpoint = torch.load(saved_model_path)
model.load_state_dict(checkpoint['state_dict'])
model.load_state_dict(checkpoint['model'])
test.test(test_loader, model, args, device)


Expand Down
20 changes: 0 additions & 20 deletions models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,23 +263,3 @@ def get_fine_tuning_parameters(model, ft_begin_index=0, lr_mult1=0.1, lr_mult2=1
model = model.to(device)
y = model(x)
print(torch.nn.functional.softmax(y, dim=1))

# parameters1 = get_fine_tuning_parameters(model, 0)
# parameters2 = list(parameters1)
# parameters3 = get_fine_tuning_parameters(model, 1)
# print(type(parameters1))
# print(type(parameters3))
# print(len(parameters2))
# print(len(parameters3))

# print(type(parameters2[0]))
# print(type(parameters3[0]['params']))

"""
<class 'generator'>
<class 'list'>
62
62
<class 'torch.nn.parameter.Parameter'>
<class 'torch.nn.parameter.Parameter'>
"""
11 changes: 5 additions & 6 deletions pretrained_models/download.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

echo "Download resnet models pretrained on ImageNet..."

wget https://download.pytorch.org/models/resnet18-5c106cde.pth
wget https://download.pytorch.org/models/resnet34-333f7ec4.pth
wget https://download.pytorch.org/models/resnet50-19c8e357.pth
wget https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
wget https://download.pytorch.org/models/resnet152-b121ed2d.pth

wget -N https://download.pytorch.org/models/resnet18-5c106cde.pth
wget -N https://download.pytorch.org/models/resnet34-333f7ec4.pth
wget -N https://download.pytorch.org/models/resnet50-19c8e357.pth
wget -N https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
wget -N https://download.pytorch.org/models/resnet152-b121ed2d.pth

0 comments on commit 91dbcff

Please sign in to comment.