diff --git a/benchmarks/_template/main.py b/benchmarks/_template/main.py
index 9169a0f07..99f0f0adc 100644
--- a/benchmarks/_template/main.py
+++ b/benchmarks/_template/main.py
@@ -6,14 +6,31 @@
 import time
 
-import voir
-from giving import give
+import torchcompat.core as accelerator
+from benchmate.observer import BenchObserver
 
 
 def main():
-    for i in voir.iterate("train", range(10000), report_batch=True, batch_size=64):
-        give(loss=1 / (i + 1))
-        time.sleep(0.1)
+    device = accelerator.fetch_device(0)  # <= This is your cuda device
+
+    observer = BenchObserver(batch_size_fn=lambda batch: 1)
+    # optimizer = observer.optimizer(optimizer)
+    # criterion = observer.criterion(criterion)
+
+    dataloader = [1, 2, 3, 4]
+
+    for epoch in range(10):
+        for i in observer.iterate(dataloader):
+            # avoid .item()
+            # avoid torch.cuda; use accelerator from torchcompat instead
+            # avoid explicit syncs (torch.cuda.synchronize or accelerator.synchronize)
+
+            # y = model(i)
+            # loss = criterion(y)
+            # loss.backward()
+            # optimizer.step()
+
+            time.sleep(0.1)
 
 
 if __name__ == "__main__":
diff --git a/benchmate/benchmate/metrics.py b/benchmate/benchmate/metrics.py
index b6ca483c7..a73c8e03d 100644
--- a/benchmate/benchmate/metrics.py
+++ b/benchmate/benchmate/metrics.py
@@ -132,6 +132,15 @@ def elapsed(self):
         return self._start.elapsed_time(self._end)
 
 
+def default_event():
+    try:
+        import torchcompat.core as accelerator
+        return accelerator.Event
+    except ImportError:
+        print("Could not find a device timer, falling back to CPU timing")
+        return CPUTimer  # return the class, to match accelerator.Event
+
+
 class TimedIterator:
     """Time the body of a loop, ignoring the time it took to initialize the iterator.`
     The timings are measured using `torch.cuda.Event` to avoid explicit sync.
@@ -197,7 +206,7 @@ def with_give(cls, *args, push=None, **kwargs):
     def __init__(
         self,
         loader,
-        event_fn,
+        event_fn=default_event(),
         rank=0,
         push=file_push(),
         device=None,
diff --git a/benchmate/benchmate/observer.py b/benchmate/benchmate/observer.py
index 5ead66a5b..aae642511 100644
--- a/benchmate/benchmate/observer.py
+++ b/benchmate/benchmate/observer.py
@@ -59,6 +59,9 @@ def override_return_value(self, function, override):
         else:
             raise RuntimeError("Not running through voir")
 
+    def iterate(self, iterator):
+        return self.loader(iterator)
+
     def loader(self, loader):
         """Wrap a dataloader or an iterable which enable accurate measuring of time spent in the loop's body"""
         self.wrapped = TimedIterator(
diff --git a/docs/new_benchmarks.rst b/docs/new_benchmarks.rst
index 26356f8c4..058b99c0d 100644
--- a/docs/new_benchmarks.rst
+++ b/docs/new_benchmarks.rst
@@ -78,15 +78,28 @@ The template ``main.py`` demonstrates a simple loop that you can adapt to any sc
 
 .. code-block:: python
 
+    def main():
-    for i in voir.iterate("train", range(100), report_batch=True, batch_size=64):
-        give(loss=1/(i + 1))
-        time.sleep(0.1)
-
-* Wrap the training loop with ``voir.iterate``.
-  * ``report_batch=True`` triggers the computation of the number of training samples per second.
-  * Set ``batch_size`` to the batch_size. milabench can also figure it out automatically if you are iterating over the input batches (it will use the first number in the tensor's shape).
-* ``give(loss=loss.item())`` will forward the value of the loss to milabench. Make sure the value is a plain Python ``float``.
+        observer = BenchObserver(batch_size_fn=lambda batch: 1)
+        criterion = observer.criterion(criterion)
+        optimizer = observer.optimizer(optimizer)
+
+        for epoch in range(10):
+            for i in observer.iterate(dataloader):
+                # ...
+                time.sleep(0.1)
+
+* Create a new ``BenchObserver``; it times the benchmark and measures batch times.
+  * Set ``batch_size_fn`` to a function that computes the effective batch size for a given batch.
+* ``observer.criterion(criterion)`` wraps the criterion so the loss is reported automatically.
+* ``observer.optimizer(optimizer)`` wraps the optimizer so devices that need special handling can run their logic there.
+* Wrap the batch loop with ``observer.iterate``; it times the body of the loop and handles early stopping if necessary. A fuller example is sketched below.
+
+.. note::
+
+   Avoid calls to ``.item()``, ``torch.cuda`` and ``torch.cuda.synchronize()``.
+   To access ``cuda``-related features, use ``accelerator`` from torchcompat;
+   it is a light wrapper around ``torch.cuda`` that allows a wider range of devices to be used.
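+
+The sketch below shows a more realistic version of the loop. It is only a
+starting point: ``build_model``, ``dataloader`` and the optimizer settings are
+placeholders for whatever your script already defines, not part of milabench's API.
+
+.. code-block:: python
+
+    import torch
+    import torchcompat.core as accelerator
+    from benchmate.observer import BenchObserver
+
+    def main():
+        device = accelerator.fetch_device(0)
+        model = build_model().to(device)  # build_model is a placeholder
+
+        # report the real batch size: the first dimension of the input tensor
+        observer = BenchObserver(batch_size_fn=lambda batch: batch[0].shape[0])
+        criterion = observer.criterion(torch.nn.CrossEntropyLoss())
+        optimizer = observer.optimizer(torch.optim.SGD(model.parameters(), lr=0.01))
+
+        for epoch in range(10):
+            for x, y in observer.iterate(dataloader):
+                x, y = x.to(device), y.to(device)
+                optimizer.zero_grad()
+                loss = criterion(model(x), y)
+                loss.backward()
+                optimizer.step()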
 
 If the script takes command line arguments, you can parse them however you like, for example with ``argparse.ArgumentParser``. Then, you can add an ``argv`` section in ``dev.yaml``, just like this:
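(The entry below is a hypothetical sketch; the benchmark name, plan and flag values are illustrative only.)

.. code-block:: yaml

    mybenchmark:
      inherits: _defaults
      definition: .
      plan:
        method: per_gpu
      argv:
        --epochs: 10
        --batch-size: 64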