Skip to content

Commit

Permalink
Last intel modif
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed May 14, 2024
1 parent 028b228 commit be3681c
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
6 changes: 5 additions & 1 deletion benchmarks/llama/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,8 @@ def main():


if __name__ == "__main__":
main()
try:
main()
except Exception as err:
# Habana likes to eat exceptions
print(err)
15 changes: 10 additions & 5 deletions benchmarks/torchvision/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ def train_epoch(model, criterion, optimizer, loader, device, dtype, scaler=None)
def toiterator(loader):
with timeit("loader"):
return iter(loader)

for inp, target in timeiterator(voir.iterate("train", toiterator(loader), True)):

iterator = timeiterator(voir.iterate("train", toiterator(loader), True))

for inp, target in iterator:

with timeit("batch"):
inp = inp.to(device, dtype=dtype)
Expand Down Expand Up @@ -150,7 +152,10 @@ def toiterator(loader):
with given() as gv:
for epoch in voir.iterate("main", range(args.epochs)):
with timeit("epoch"):
for inp, target in timeiterator(voir.iterate("train", toiterator(loader), True)):

iterator = timeiterator(voir.iterate("train", toiterator(loader), True))

for inp, target in iterator:
with timeit("batch"):
inp = inp.to(device, dtype=dtype)
target = target.to(device)
Expand Down Expand Up @@ -251,7 +256,7 @@ def _main():
if args.iobench:
iobench(args)
else:
trainbench()
trainbench(args)

def trainbench(args):
if args.fixed_batch:
Expand All @@ -277,7 +282,7 @@ def trainbench(args):

optimizer = torch.optim.SGD(model.parameters(), args.lr)

model, optimizer = accelerator.optimizer(model, optimizer=optimizer, dtype=float_dtype(args.precision))
model, optimizer = accelerator.optimize(model, optimizer=optimizer, dtype=float_dtype(args.precision))

if args.data:
train_loader = dataloader(args)
Expand Down
5 changes: 4 additions & 1 deletion benchmarks/torchvision_ddp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,16 @@ def image_transforms():
return data_transforms

def prepare_dataloader(dataset: Dataset, args):
dsampler = DistributedSampler(dataset)
# next(iter(dsampler))

return DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers if not args.noio else 0,
pin_memory=not args.noio,
shuffle=False,
sampler=DistributedSampler(dataset)
sampler=dsampler
)

class FakeDataset:
Expand Down
1 change: 1 addition & 0 deletions config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ _timm:
--amp-dtype: bfloat16
## FIXME
--device: hpu
--dist-backend: hccl

_sb3:
inherits: _defaults
Expand Down

0 comments on commit be3681c

Please sign in to comment.