Commit: Update Template example

Delaunay authored Jun 21, 2024
1 parent 64e692e commit c385ccc

Showing 4 changed files with 56 additions and 14 deletions.
27 changes: 22 additions & 5 deletions benchmarks/_template/main.py
@@ -6,14 +6,31 @@

import time

-import voir
-from giving import give
+import torchcompat.core as accelerator
+from benchmate.observer import BenchObserver


def main():
-    for i in voir.iterate("train", range(10000), report_batch=True, batch_size=64):
-        give(loss=1 / (i + 1))
-        time.sleep(0.1)
+    device = accelerator.fetch_device(0)  # <= This is your cuda device
+
+    observer = BenchObserver(batch_size_fn=lambda batch: 1)
+    # optimizer = observer.optimizer(optimizer)
+    # criterion = observer.criterion(criterion)
+
+    dataloader = [1, 2, 3, 4]
+
+    for epoch in range(10):
+        for i in observer.iterate(dataloader):
+            # avoid .item()
+            # avoid torch.cuda; use accelerator from torchcompat instead
+            # avoid torch.cuda.synchronize or accelerator.synchronize
+
+            # y = model(i)
+            # loss = criterion(y)
+            # loss.backward()
+            # optimizer.step()
+
+            time.sleep(0.1)


if __name__ == "__main__":
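
For illustration, a minimal sketch of how the commented-out pieces of this template might be filled in. The model, optimizer, learning rate and synthetic data below are placeholders, not part of the commit:

import torch
import torch.nn as nn

import torchcompat.core as accelerator
from benchmate.observer import BenchObserver


def main():
    device = accelerator.fetch_device(0)

    model = nn.Linear(32, 2).to(device)                        # placeholder model
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # placeholder optimizer

    observer = BenchObserver(batch_size_fn=lambda batch: batch[0].shape[0])
    optimizer = observer.optimizer(optimizer)
    criterion = observer.criterion(criterion)

    # placeholder data standing in for a real dataloader
    dataloader = [
        (torch.randn(64, 32, device=device), torch.randint(0, 2, (64,), device=device))
        for _ in range(8)
    ]

    for epoch in range(10):
        for x, y in observer.iterate(dataloader):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()


if __name__ == "__main__":
    main()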
11 changes: 10 additions & 1 deletion benchmate/benchmate/metrics.py
@@ -132,6 +132,15 @@ def elapsed(self):
        return self._start.elapsed_time(self._end)


+def default_event():
+    try:
+        import torchcompat.core as accelerator
+        return accelerator.Event
+    except Exception:
+        print("Could not find a device timer")
+        return CPUTimer()
+
+
class TimedIterator:
    """Time the body of a loop, ignoring the time it took to initialize the iterator.
    The timings are measured using `torch.cuda.Event` to avoid explicit sync.
@@ -197,7 +206,7 @@ def with_give(cls, *args, push=None, **kwargs):
    def __init__(
        self,
        loader,
-        event_fn,
+        event_fn=default_event(),
        rank=0,
        push=file_push(),
        device=None,
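
The ``TimedIterator`` docstring above refers to event-based timing. As a rough sketch of that idea using plain ``torch.cuda`` (a CUDA device is assumed and the workload is a placeholder; benchmate itself selects the event type through ``default_event``), one can record an event pair per iteration and only synchronize once at the end:

import torch

def step(x):
    # placeholder GPU work standing in for forward/backward/optimizer
    return (x @ x).sum()

data = [torch.randn(1024, 1024, device="cuda") for _ in range(10)]

events = []
for batch in data:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    step(batch)
    end.record()
    events.append((start, end))   # no sync inside the loop

torch.cuda.synchronize()          # one sync at the end, then read every timing
times_ms = [s.elapsed_time(e) for s, e in events]
print(times_ms)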
3 changes: 3 additions & 0 deletions benchmate/benchmate/observer.py
@@ -59,6 +59,9 @@ def override_return_value(self, function, override):
        else:
            raise RuntimeError("Not running through voir")

+    def iterate(self, iterator):
+        return self.loader(iterator)
+
    def loader(self, loader):
        """Wrap a dataloader or an iterable, which enables accurate measurement of the time spent in the loop's body"""
        self.wrapped = TimedIterator(
29 changes: 21 additions & 8 deletions docs/new_benchmarks.rst
@@ -78,15 +78,28 @@ The template ``main.py`` demonstrates a simple loop that you can adapt to any sc

.. code-block:: python

-    def main():
-        for i in voir.iterate("train", range(100), report_batch=True, batch_size=64):
-            give(loss=1/(i + 1))
-            time.sleep(0.1)

-* Wrap the training loop with ``voir.iterate``.
-* ``report_batch=True`` triggers the computation of the number of training samples per second.
-* Set ``batch_size`` to the batch_size. milabench can also figure it out automatically if you are iterating over the input batches (it will use the first number in the tensor's shape).
-* ``give(loss=loss.item())`` will forward the value of the loss to milabench. Make sure the value is a plain Python ``float``.

+    observer = BenchObserver(batch_size_fn=lambda batch: 1)
+    criterion = observer.criterion(criterion)
+    optimizer = observer.optimizer(optimizer)
+
+    for epoch in range(10):
+        for i in observer.iterate(dataloader):
+            # ...
+            time.sleep(0.1)

+* Create a new ``BenchObserver``; this class times the benchmark and measures batch times.
+* Set ``batch_size_fn`` to a function that computes the effective batch size from a batch (see the sketch below).
+* ``observer.criterion(criterion)`` wraps the criterion so that the loss is reported automatically.
+* ``observer.optimizer(optimizer)`` wraps the optimizer so that devices needing special handling can have their logic executed there.
+* Wrap the batch loop with ``observer.iterate``; it takes care of timing the body of the loop and handles early stopping if necessary.

+.. note::
+
+   Avoid calls to ``.item()``, ``torch.cuda`` and ``torch.cuda.synchronize()``.
+   To access ``cuda``-related features, use ``accelerator`` from torchcompat.
+   ``accelerator`` is a light wrapper around ``torch.cuda`` that allows a wider range of devices to be used.
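
For example, two hedged ``batch_size_fn`` variants for common batch layouts (illustrations only, not taken from the commit):

from benchmate.observer import BenchObserver

# If the loader yields (x, y) tuples, the batch size is the first dimension of x
observer = BenchObserver(batch_size_fn=lambda batch: batch[0].shape[0])

# If the loader yields plain tensors, read the first dimension directly
observer = BenchObserver(batch_size_fn=lambda batch: batch.shape[0])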

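To make the note concrete, a minimal device-agnostic sketch; ``fetch_device`` is shown in the template above, and the rest assumes ``accelerator`` slots into the usual torch device handling:

import torch
import torchcompat.core as accelerator

device = accelerator.fetch_device(0)   # instead of torch.device("cuda:0")
x = torch.randn(4, 4).to(device)       # regular torch code is unchanged
loss = (x * x).sum()
# keep `loss` on the device: no loss.item(), no printing its value,
# and no torch.cuda.synchronize() inside the training loop
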
If the script takes command line arguments, you can parse them however you like, for example with ``argparse.ArgumentParser``. Then, you can add an ``argv`` section in ``dev.yaml``, just like this:

