Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jun 25, 2024
1 parent 94b27a7 commit e0d6731
Show file tree
Hide file tree
Showing 9 changed files with 356 additions and 161 deletions.
2 changes: 1 addition & 1 deletion benchmarks/accelerate_opt/benchfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AccelerateBenchmark(Package):

def make_env(self):
env = super().make_env()
value = resolve_placeholder(pack, "--cpus_per_gpu")
value = resolve_placeholder(self, "--cpus_per_gpu")
env["OMP_NUM_THREADS"] = str(value)
return env

Expand Down
5 changes: 1 addition & 4 deletions benchmate/benchmate/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,7 @@ def pytorch(folder, batch_size, num_workers, distributed=False, epochs=60):
def synthetic(model, batch_size, fixed_batch):
return SyntheticData(
tensors=generate_tensor_classification(
model,
batch_size,
(3, 244, 244),
device=accelerator.fetch_device(0)
model, batch_size, (3, 244, 244), device=accelerator.fetch_device(0)
),
n=1000,
fixed_batch=fixed_batch,
Expand Down
4 changes: 2 additions & 2 deletions benchmate/benchmate/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def monitor_fn():
}
for gpu in get_gpu_info()["gpus"].values()
}
mblog({"task": "main", "gpudata": data})
mblog({"task": "train", "gpudata": data})

monitor_fn()
monitor = Monitor(3, monitor_fn)
Expand Down Expand Up @@ -74,7 +74,7 @@ def monitor_fn():
}
for gpu in get_gpu_info()["gpus"].values()
}
return {"task": "main", "gpudata": data, "time": time.time(), "units": "s"}
return {"task": "train", "gpudata": data, "time": time.time(), "units": "s"}

monitor = CustomMonitor(0.5, monitor_fn)

Expand Down
8 changes: 7 additions & 1 deletion benchmate/benchmate/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ class BenchObserver:
"""

def __init__(
self, *args, backward_callback=None, step_callback=None, stdout=False, rank=None, **kwargs
self,
*args,
backward_callback=None,
step_callback=None,
stdout=False,
rank=None,
**kwargs,
):
self.wrapped = None
self.args = args
Expand Down
Loading

0 comments on commit e0d6731

Please sign in to comment.