Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay authored Jun 21, 2024
1 parent fb56c6f commit 0d721a5
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions benchmarks/_template/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,23 @@

import time

import voir
from giving import give
import torchcompat.core as accelerator
from benchmate.observer import BenchObserver


def main():
for i in voir.iterate("train", range(10000), report_batch=True, batch_size=64):
give(loss=1 / (i + 1))
time.sleep(0.1)
device = accelerator.fetch_device(0) # <= This is your cuda device

observer = BenchObserver()
dataloader = [1, 2, 3, 4]

for epoch in range(10):
for i in observer.iterate(dataloader):
# avoid .item()
# avoid torch.cuda; use accelerator from torchcompat instead
# avoid torch.cuda.synchronize or accelerator.synchronize
observer.record_loss(loss=1 / (i + 1))
time.sleep(0.1)


if __name__ == "__main__":
Expand Down

0 comments on commit 0d721a5

Please sign in to comment.