Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resnet 50 #95

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions example/synchronous_resnet/resnet_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from __future__ import absolute_import, print_function
import os
import sys
import time

import numpy

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from platoon.channel import Controller


class ResNetController(Controller):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you can't reuse the default controler? I don't like the idea to ask every user to make there own controler. If for this example we need a new controler, do we need to update the class Controler to meat the current need?

Copy link
Contributor Author

@olimastro olimastro Jun 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mostly copied everything from the lstm example and work from there. You are right though that if I do a diff between the lstm_controller and resnet_controller there are no differences except the names (the time request is just for the timing and won't end up in final release). There could just be a general controller used by both examples. I do have to implement the method handle_control and subclass the general Controler class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, just reuse the LSTMController directly. We can see what we do later for a general controler

"""
This multi-process controller implements patience-based early-stopping SGD
"""

def __init__(self, seed, patience, default_args):
"""
Initialize the ResNetController

Parameters
----------
max_mb : int
Max number of minibatches to train on.
patience: : int
Training stops when this many minibatches have been trained on
without any reported improvement.
valid_freq : int
Number of minibatches to train on between every monitoring step.
default_args : dict
Arguments of default class Controller
"""
super(ResNetController, self).__init__(**default_args)
self.nb_worker = len(self._devices)
# map ids to members of range(nb_worker)
self.worker_ids_dict = dict(zip(self._workers, [i for i in range(len(self._workers))]))

self.patience = patience
self.seed = seed

self.valid_history_errs = [[None for i in range(self.nb_worker)]]
self.test_history_errs = [[None for i in range(self.nb_worker)]]
self.bad_counter = 0
self._epoch = 0
self.best_dict = dict(best__epoch=-1, best_valid=numpy.inf)


def handle_control(self, req, worker_id, req_info):
"""
Handles a control_request received from a worker

Parameters
----------
req : str or dict
Control request received from a worker.
The control request can be one of the following
1) "next" : request by a worker to be informed of its next action
to perform. The answers from the server can be 'train' (the
worker should keep training on its training data), 'valid' (the
worker should perform monitoring on its validation set and test
set) or 'stop' (the worker should stop training).
2) dict of format {"done":N} : used by a worker to inform the
server that is has performed N more training iterations and
synced its parameters. The server will respond 'stop' if the
maximum number of training minibatches has been reached.
3) dict of format {"valid_err":x, "test_err":x2} : used by a worker
to inform the server that it has performed a monitoring step
and obtained the included errors on the monitoring datasets.
The server will respond "best" if this is the best reported
validation error so far, otherwise it will respond 'stop' if
the patience has been exceeded.
"""
control_response = ""
worker_id = self.worker_ids_dict[worker_id]

if req == 'pred_errors':
if self.valid_history_errs[self._epoch][worker_id] is not None:
# if a worker tries to add a valid error where there is no None
# it means it tries to index after or before current _epoch
raise RuntimeError('Worker got out of synch!')
self.valid_history_errs[self._epoch][worker_id] = req_info['valid_err']
self.test_history_errs[self._epoch][worker_id] = req_info['test_err']

if not any([i is None for i in self.valid_history_errs[self._epoch]]):
print('Epoch %d is done'%req_info['epoch'])
valid_err = sum(self.valid_history_errs[self._epoch]) / float(self.nb_worker)

if valid_err <= self.best_dict['best_valid']:
self.best_dict['best_epoch'] = self._epoch
self.best_dict['best_valid'] = valid_err
self.bad_counter = 0
control_response = 'best'
print("Best error valid:", valid_err)
else:
self.bad_counter += 1
self.valid_history_errs += [[None for i in range(self.nb_worker)]]
self.test_history_errs += [[None for i in range(self.nb_worker)]]
self._epoch += 1

elif req == 'splits':
# the controller never loads the dataset but the worker doesn't
# know how many workers there are
train_len = req_info['train_len'] // self.nb_worker
valid_len = req_info['valid_len'] // self.nb_worker
test_len = req_info['test_len'] // self.nb_worker
splits = dict(train_splits=[train_len * worker_id, train_len * (worker_id + 1)],
valid_splits=[valid_len * worker_id, valid_len * (worker_id + 1)],
test_splits=[test_len * worker_id, test_len * (worker_id + 1)])
control_response = splits

# kind of when the training start but not really
self.start_time = time.time()

elif req == 'seed':
control_response = self.seed

if self.bad_counter > self.patience:
print("Early stopping!")
end_time = time.time() - self.start_time
# should terminate with best printing and best dumping of params
# and then close everything
print("Best error valid:", self.best_dict['best_valid'])
test_err = sum(self.test_history_errs[self.best_dict['best_epoch']]) / \
float(self.nb_worker)
print("Best error test:", test_err)
print( ("Training took %.1fs" % (end_time)), file=sys.stderr)
control_response = 'stop'

return control_response


def resnet_control(saveFreq=1110, saveto=None):
parser = Controller.default_parser()
parser.add_argument('--seed', default=1234, type=int,
required=False, help='Maximum mini-batches to train upon in total.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update the help

parser.add_argument('--patience', default=10, type=int, required=False,
help='Maximum patience when failing to get better validation results.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose it is in mini-batch unit. Can you specify that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the original lstm example it is from validation epoch actually, so since the validation is not done after every training epoch then this number has to be not too large. Could make the help more useful indeed.

args = parser.parse_args()

l = ResNetController(seed=args.seed,
patience=args.patience,
default_args=Controller.default_arguments(args))

print("Controller is ready")
return l.serve()

if __name__ == '__main__':
rcode = resnet_control()
if rcode != 0:
sys.exit(rcode)
Loading