From 396c96357b8cf3ef022aed4e1a073ff2a2d80694 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 25 Sep 2024 23:55:05 -0400 Subject: [PATCH] update mlperf bert scripts (#6755) removed DISABLE_DROPOUT=1. updated BS to 54 that works on tinyboxes with dropouts. used bert's sparse_categorical_crossentropy that takes Tensor ignore_index in accuracy method --- .../implementations/tinybox_green/dev_beam.sh | 5 +---- .../implementations/tinybox_green/dev_run.sh | 5 +---- .../tinybox_green/run_and_time.sh | 5 +---- .../implementations/tinybox_red/dev_beam.sh | 5 +---- .../implementations/tinybox_red/dev_run.sh | 5 +---- .../tinybox_red/run_and_time.sh | 5 +---- extra/models/bert.py | 18 +++++++++--------- tinygrad/tensor.py | 2 +- 8 files changed, 16 insertions(+), 34 deletions(-) diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh index 6c21de0d6c690..368d2cc9b4fc9 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh @@ -2,14 +2,11 @@ export PYTHONPATH="." export MODEL="bert" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6 +export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 export BASEDIR="/raid/datasets/wiki" -echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN" -export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run. - export BENCHMARK=10 DEBUG=2 python3 examples/mlperf/model_train.py diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh index 09bbdc1a12623..672cd2f24ef57 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh @@ -2,14 +2,11 @@ export PYTHONPATH="." export MODEL="bert" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6 +export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 export BASEDIR="/raid/datasets/wiki" -echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN" -export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run. - export WANDB=1 python3 examples/mlperf/model_train.py \ No newline at end of file diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh index 5cce3e133e27a..7c3a928242ce1 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh @@ -3,14 +3,11 @@ export PYTHONPATH="." export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_green" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6 +export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 export BASEDIR="/raid/datasets/wiki" -echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN" -export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run. - # pip install -e ".[mlperf]" export LOGMLPERF=1 diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh index f887ee4ff2b77..368d2cc9b4fc9 100644 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh @@ -2,14 +2,11 @@ export PYTHONPATH="." export MODEL="bert" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6 +export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 export BASEDIR="/raid/datasets/wiki" -echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN" -export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run. - export BENCHMARK=10 DEBUG=2 python3 examples/mlperf/model_train.py diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh index a725c167eb880..672cd2f24ef57 100644 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh @@ -2,14 +2,11 @@ export PYTHONPATH="." export MODEL="bert" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6 +export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 export BASEDIR="/raid/datasets/wiki" -echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN" -export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run. - export WANDB=1 python3 examples/mlperf/model_train.py \ No newline at end of file diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh index e5dd33caa5d98..e6036c027b4a7 100644 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh @@ -3,14 +3,11 @@ export PYTHONPATH="." export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_red" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6 +export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 export BASEDIR="/raid/datasets/wiki" -echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN" -export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run. - # pip install -e ".[mlperf]" export LOGMLPERF=1 diff --git a/extra/models/bert.py b/extra/models/bert.py index 4891888bd8d9e..4503ed0aed2e2 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -49,15 +49,15 @@ def __call__(self, input_ids:Tensor, attention_mask:Tensor, masked_lm_positions: output = self.bert(input_ids, attention_mask, token_type_ids) return self.cls(output, masked_lm_positions) + # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315 + def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1): + log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index) + y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1]) + y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1]) + return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero + def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): - # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315 - def sparse_categorical_crossentropy(predictions:Tensor, labels:Tensor, ignore_index=-1): - log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index) - y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1]) - y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1]) - return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero - - masked_lm_loss = sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights) + masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights) next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels) return masked_lm_loss + next_sentence_loss @@ -66,7 +66,7 @@ def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, mas valid = masked_lm_ids != 0 masked_lm_predictions = prediction_logits.log_softmax().argmax(-1) masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid - masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights) + masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights) seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1) seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 635cb7c7aee78..2da7c7f779724 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3141,7 +3141,7 @@ def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> """ return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction) - def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor: + def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor: """ Computes the sparse categorical cross-entropy loss between `self` and `Y`.