
force HCCL when hpu
Your Name committed May 21, 2024
1 parent 1ce5ff4 commit 49a6458
Showing 1 changed file with 26 additions and 27 deletions.
benchmarks/accelerate_opt/main.py (53 changes: 26 additions, 27 deletions)
@@ -129,6 +129,18 @@ class CustomInitProcessGroupKwargs(InitProcessGroupKwargs):
         rank=int(os.environ["RANK"]),
         world_size=int(os.environ["WORLD_SIZE"]),
     )
+    print(init_process_group_kwargs.backend)
+
+    # Accelerator sucks: it is impossible to make it use hccl,
+    # but we can bypass its logic by initializing the group ourselves.
+    if acc.device_type == "hpu":
+        acc.init_process_group(
+            init_method=f"tcp://{MASTER_ADDR}:{MASTER_PORT}",
+            timeout=timedelta(seconds=60),
+            rank=int(os.environ["RANK"]),
+            world_size=int(os.environ["WORLD_SIZE"]),
+        )
+
     accelerator = Accelerator(kwargs_handlers=[init_process_group_kwargs])
 else:
     accelerator = Accelerator()
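
For reference, forcing hccl without Accelerate amounts to registering Habana's backend and calling torch.distributed.init_process_group directly. A minimal sketch, assuming Habana's PyTorch bridge is installed and the usual rendezvous environment variables are set (an illustration of the idea, not the code acc.init_process_group actually runs):

import os
from datetime import timedelta

import torch.distributed as dist

# Importing this module registers the "hccl" backend with torch.distributed
# (assumption: the habana_frameworks package is available).
import habana_frameworks.torch.distributed.hccl  # noqa: F401

dist.init_process_group(
    backend="hccl",
    init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
    timeout=timedelta(seconds=60),
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)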
@@ -377,14 +389,26 @@ def group_texts(examples):
     starting_epoch = 0
     last_log_time = time.time()
 
+    from voir.wrapper import Wrapper
+    wrapper = Wrapper(
+        event_fn=acc.Event,
+        earlystop=30,
+        rank=int(os.environ["RANK"]),
+        device=acc.fetch_device(int(os.environ["RANK"])),
+        stdout=True,
+    )
+    loader = wrapper.loader(train_dataloader)
+
     for epoch in range(starting_epoch, num_train_epochs):
         model.train()
-        for step, batch in enumerate(train_dataloader):
+        for step, batch in enumerate(loader):
             outputs = model(**batch)
             loss = outputs.loss
             loss = loss / gradient_accumulation_steps
-            if accelerator.is_main_process:
-                mblog({"task": "train", "loss": loss.detach().item()})
+            loader.add_loss(loss)
+            # mblog({"task": "train", "loss": loss.detach().item()})
 
             accelerator.backward(loss)
 
             if (step + 1) % gradient_accumulation_steps == 0 or step == len(
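
The voir Wrapper takes over the per-step bookkeeping that the next hunk deletes: it wraps the dataloader so each batch is timed (event_fn=acc.Event suggests device-side events), collects losses pushed through add_loss, and ends the run after earlystop observations. A rough sketch of the idea under those assumptions; the class below is a hypothetical stand-in, not voir's implementation:

import time

class TimedLoader:
    # Hypothetical stand-in for the loader returned by voir.wrapper.Wrapper.
    def __init__(self, loader, earlystop=30):
        self.loader = loader
        self.earlystop = earlystop
        self.losses = []

    def add_loss(self, loss):
        # Detach but do not call .item() here, to avoid a device sync per step.
        self.losses.append(loss.detach())

    def __iter__(self):
        start = time.perf_counter()
        for step, batch in enumerate(self.loader):
            yield batch
            rate = (step + 1) / (time.perf_counter() - start)
            print({"task": "train", "rate": rate, "units": "batches/s"})
            if step + 1 >= self.earlystop:
                break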
@@ -399,31 +423,6 @@ def group_texts(examples):
                 if not accelerator.optimizer_step_was_skipped:
                     completed_steps += 1
 
-                log_interval = 3
-                if accelerator.is_main_process and completed_steps % log_interval == 0:
-                    acc.synchronize()
-
-                    if completed_steps == 0:
-                        last_log_time = time.time()
-                    else:
-                        seconds_since_last_log = time.time() - last_log_time
-
-                        n_samples_since_last_log = log_interval * total_batch_size
-
-                        throughput_samples_per_sec = (
-                            n_samples_since_last_log / seconds_since_last_log
-                        )
-
-                        mblog(
-                            {
-                                "task": "train",
-                                "rate": throughput_samples_per_sec,
-                                "units": "items/s",
-                            }
-                        )
-
-                        last_log_time = time.time()
-
             if completed_steps >= max_train_steps:
                 break

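The deleted block computed throughput by hand once every log_interval optimizer steps: samples per second equals log_interval * total_batch_size divided by the time elapsed since the last log. A quick worked example with illustrative numbers:

log_interval = 3          # optimizer steps between logs
total_batch_size = 64     # per-device batch * world size * grad accumulation
seconds_since_last_log = 1.5

rate = log_interval * total_batch_size / seconds_since_last_log
print(rate, "items/s")    # 128.0 items/s

Presumably the voir wrapper now reports this rate on its own, which is why the manual version could go.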
