
Commit

Merge branch 'master' into dropout-contiguous
chenyuxyz authored Oct 6, 2024
2 parents 9e7a87f + 9eb6eef commit 0371c40
Showing 24 changed files with 384 additions and 291 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -1,7 +1,7 @@
name: Unit Tests
env:
# increment this when downloads substantially change to avoid the internet
DOWNLOAD_CACHE_VERSION: '5'
DOWNLOAD_CACHE_VERSION: '6'
RUN_PROCESS_REPLAY: 1
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PYTHONPATH: .
8 changes: 8 additions & 0 deletions docs/mnist.md
@@ -98,6 +98,14 @@ timeit.repeat(step, repeat=5, number=1)

So around 75 ms on T4 colab.

If you want to see a breakdown of the time by kernel:

```python
from tinygrad import GlobalCounters, Context
GlobalCounters.reset()
with Context(DEBUG=2): step()
```

### Why so slow?

Unlike PyTorch, tinygrad isn't designed to be fast like that. While 75 ms for one step is plenty fast for debugging, it's not great for training. Here, we introduce the first quintessentially tinygrad concept, the `TinyJit`.
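
Not part of the diff, but for context: a minimal sketch of how `TinyJit` is typically applied, with a hypothetical stand-in for the tutorial's `step` function (the real one runs the forward/backward pass and the optimizer).

```python
from tinygrad import Tensor, TinyJit

@TinyJit
def step() -> Tensor:
  # stand-in body; the tutorial's real step does a training update and returns the loss
  x = Tensor.rand(4, 4)
  return (x @ x).sum().realize()

# the first couple of calls trace and capture the kernels; later calls replay the
# captured graph, skipping Python-side scheduling overhead
for _ in range(5):
  out = step()
```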
1 change: 1 addition & 0 deletions docs/tensor/ops.md
@@ -13,6 +13,7 @@
::: tinygrad.Tensor.softmax
::: tinygrad.Tensor.log_softmax
::: tinygrad.Tensor.logsumexp
::: tinygrad.Tensor.logcumsumexp
::: tinygrad.Tensor.argmax
::: tinygrad.Tensor.argmin

11 changes: 9 additions & 2 deletions examples/mlperf/model_train.py
@@ -28,7 +28,7 @@ def train_resnet():
if getenv("LOGMLPERF"):
from mlperf_logging import mllog
import mlperf_logging.mllog.constants as mllog_constants
mllog.config(filename=f"result_{seed}.txt")
mllog.config(filename=f"result_resnet_{seed}.txt")
mllog.config(root_dir=Path(__file__).parents[3].as_posix()) # truncate to log this. "file": "tinygrad/examples/mlperf/model_train.py"
MLLOGGER = mllog.get_mllogger()
if INITMLPERF:
@@ -621,7 +621,7 @@ def train_bert():
from mlperf_logging import mllog
import mlperf_logging.mllog.constants as mllog_constants

mllog.config(filename="bert.log")
mllog.config(filename=f"result_bert_{seed}.log")
mllog.config(root_dir=Path(__file__).parents[3].as_posix())
MLLOGGER = mllog.get_mllogger()
MLLOGGER.logger.propagate = False
@@ -752,7 +752,12 @@ def train_bert():
else:
i, train_data = start_step, get_fake_data_bert(GPUS, BS)

epoch_started = False
while train_data is not None and i < train_steps and not achieved:
if not epoch_started and MLLOGGER and RUNMLPERF:
MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i+1, metadata=dict(epoch_num=i+1))
epoch_started = True

Tensor.training = True
BEAM.value = TRAIN_BEAM
st = time.perf_counter()
@@ -801,6 +806,8 @@ def train_bert():
# ** eval loop **
if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
if MLLOGGER and RUNMLPERF:
epoch_started = False
MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=i+1, metadata=dict(epoch_num=i+1))
MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i})
if getenv("RESET_STEP", 1): train_step_bert.reset()
eval_lm_losses = []
(changes to another file; path not shown)
@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6

export BEAM=4
export BEAM=3
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

(changes to another file; path not shown)
@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6

export BEAM=4
export BEAM=3
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

(changes to another file; path not shown)
@@ -5,7 +5,7 @@ export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_red"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6

export BEAM=4
export BEAM=3
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

7 changes: 3 additions & 4 deletions extra/models/bert.py
@@ -51,7 +51,7 @@ def __call__(self, input_ids:Tensor, attention_mask:Tensor, masked_lm_positions:

# Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1):
log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index)
log_probs, loss_mask = predictions.log_softmax(dtype=dtypes.float), (labels != ignore_index)
y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero
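
Not part of the diff: a tiny numeric sketch (NumPy, hypothetical values) of the masked normalization used above, where the loss is averaged over valid positions and a small residual keeps the denominator non-zero even if every label is ignored.

```python
import numpy as np

log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                             [0.1, 0.8, 0.1]]))  # per-position log-probabilities
labels = np.array([0, -1])                       # second position uses ignore_index=-1
loss_mask = labels != -1                         # valid positions only
y = np.eye(3)[np.where(loss_mask, labels, 0)] * loss_mask[:, None]  # masked one-hot targets
loss = -(log_probs * y).sum() / (loss_mask.sum() + 1e-5)  # residual avoids division by zero
print(round(float(loss), 3))  # ~0.357, i.e. -log(0.7): only the unmasked position contributes
```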
@@ -62,13 +62,12 @@ def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_
return masked_lm_loss + next_sentence_loss

def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):

valid = masked_lm_ids != 0
masked_lm_predictions = prediction_logits.log_softmax().argmax(-1)
masked_lm_predictions = prediction_logits.log_softmax(dtype=dtypes.float).argmax(-1)
masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)

seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1)
seq_relationship_predictions = seq_relationship_logits.log_softmax(dtype=dtypes.float).argmax(-1)
seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)

2 changes: 1 addition & 1 deletion extra/models/clip.py
@@ -440,7 +440,7 @@ def prepare_image(self, image:Image.Image) -> Tensor:
top = (h - SIZE) // 2
image = image.crop((0, SIZE, top, top+SIZE))

x = Tensor(np.array(image.convert('RGB')))
x = Tensor(np.array(image.convert('RGB')), device=self.std.device)
x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0
return (x - self.mean) / self.std

19 changes: 19 additions & 0 deletions test/test_dtype.py
@@ -760,6 +760,25 @@ def test_mean_half_precision_overflow(self):
t.square().mean().backward()
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)

@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_softmax_dtype(self):
data = [1, 2, 3]
t = Tensor(data, dtype=dtypes.half)
tt = torch.tensor(data, dtype=torch.half)

out = t.softmax(0)
self.assertEqual(out.dtype, dtypes.half)
np.testing.assert_allclose(out.numpy(), tt.softmax(0).numpy(), rtol=1e-3)
out = t.softmax(0, dtype=dtypes.float)
self.assertEqual(out.dtype, dtypes.float)
np.testing.assert_allclose(out.numpy(), tt.softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
out = t.log_softmax(0)
self.assertEqual(out.dtype, dtypes.half)
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3)
out = t.log_softmax(0, dtype=dtypes.float)
self.assertEqual(out.dtype, dtypes.float)
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3)

class TestImplicitFunctionTypeChange(unittest.TestCase):
def test_functions(self):
result = []
14 changes: 8 additions & 6 deletions test/test_gc.py
@@ -21,24 +21,26 @@ def test_gc(self):
(a*b).mean().backward()
assert (tensors_allocated() > 0)
del a,b
assert (tensors_allocated() == 1) # one for Tensor._device_rng_counters
assert (tensors_allocated() == 2) # one for Tensor._device_rng_counters, and one for Tensor._device_seeds
Tensor.manual_seed(0)

def test_gc_complex(self):
Tensor.manual_seed(0)
a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
b = Tensor.rand(4, 4, requires_grad=True)
assert (tensors_allocated() == 4)
(a*b).mean().backward()
assert (tensors_allocated() == 5)
(a*b).mean().backward()
assert (tensors_allocated() == 6)
del b
assert (tensors_allocated() == 3)
assert (tensors_allocated() == 4)
b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
print(tensors_allocated())
(a*b).mean().backward()
print(tensors_allocated())
assert (tensors_allocated() == 5)
assert (tensors_allocated() == 6)
del b
assert (tensors_allocated() == 3)
assert (tensors_allocated() == 4)
Tensor.manual_seed(0)

def test_schedule_gc(self):
init = bufs_allocated()
8 changes: 8 additions & 0 deletions test/test_jit.py
@@ -307,6 +307,14 @@ def f(a, b):
assert len(res3) == 10, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"

def test_jit_random_after_unrealized_random(self):
@TinyJit
def f(): return Tensor.rand()
Tensor.manual_seed(1234)
Tensor.rand()
res = [f().numpy() for _ in range(3)]
assert res[1] != res[2]

def test_jit_realization_and_sampling(self):
w = Tensor.eye(5)

8 changes: 8 additions & 0 deletions test/test_ops.py
@@ -1072,6 +1072,14 @@ def test_logsumexp(self):
helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)

def test_logcumsumexp(self):
helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, grad_atol=1e-7)
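
Not part of the diff: a minimal sketch of what the new op computes, a numerically stable log(cumsum(exp(x))) along an axis; the values here are assumptions, chosen small enough that the naive form doesn't overflow.

```python
import numpy as np
from tinygrad import Tensor

t = Tensor([[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]])
stable = t.logcumsumexp(1).numpy()        # the new op, along axis 1
naive = t.exp().cumsum(1).log().numpy()   # naive form, fine for small values like these
np.testing.assert_allclose(stable, naive, rtol=1e-5)
```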

def test_sinh(self):
helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6)
# TODO: backward nan instead of inf
33 changes: 29 additions & 4 deletions test/test_randomness.py
@@ -76,7 +76,7 @@ def test_rand_float16(self):
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))

@unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry")
def test_threefly_against_reference(self):
def test_threefry_against_reference(self):
Tensor.manual_seed(1337)

# reference generated using
@@ -92,11 +92,11 @@ def test_threefly_against_reference(self):

counts = Tensor.arange(20, dtype=dtypes.uint32)
counts0, counts1 = counts.chunk(2)
r = Tensor._threefry_random_bits(1337, 0, counts0, counts1).numpy()
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()

np.testing.assert_allclose(jr, r)

def test_threefly_against_reference_full(self):
def test_threefry_against_reference_full(self):
Tensor.manual_seed(1337)

# reference generated using
@@ -118,7 +118,7 @@ def test_threefly_against_reference_full(self):
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
def test_threefly_tensors_cnt(self):
def test_threefry_tensors_cnt(self):
Tensor.manual_seed(1337)

Tensor.rand(20).realize()
@@ -136,6 +136,31 @@ def test_threefly_tensors_cnt(self):
assert len(Tensor._device_rng_counters) == 0
assert len(Tensor._device_seeds) == 0

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
def test_threefry_same_kernels(self):
Tensor.manual_seed(0)

Tensor.rand(1).realize()

s = Tensor.rand(20).schedule()
s2 = Tensor.rand(20).schedule()

assert len(s) == len(s2), f"{len(s)} != {len(s2)}"
for x,y in zip(s, s2):
if not (x.ast == y.ast):
print(f"{x.ast} != {y.ast}")

Tensor.rand(1, device=f"{Device.DEFAULT}:1").realize()

s3 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()
s4 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()

assert len(s3) == len(s4), f"{len(s3)} != {len(s4)}"
assert len(s2) == len(s4), f"{len(s)} != {len(s3)}"
for x,y in zip(s3, s4):
if not (x.ast == y.ast):
print(f"{x.ast} != {y.ast}")

@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
def test_rand_bfloat16(self):
N = 128
(remaining changed files not shown)
