diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b2d526756a00..581542bd2b16 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,7 +1,7 @@ name: Unit Tests env: # increment this when downloads substantially change to avoid the internet - DOWNLOAD_CACHE_VERSION: '5' + DOWNLOAD_CACHE_VERSION: '6' RUN_PROCESS_REPLAY: 1 GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} PYTHONPATH: . diff --git a/docs/mnist.md b/docs/mnist.md index 2cd34c50df44..8aae08f241d9 100644 --- a/docs/mnist.md +++ b/docs/mnist.md @@ -98,6 +98,14 @@ timeit.repeat(step, repeat=5, number=1) So around 75 ms on T4 colab. +If you want to see a breakdown of the time by kernel: + +```python +from tinygrad import GlobalCounters, Context +GlobalCounters.reset() +with Context(DEBUG=2): step() +``` + ### Why so slow? Unlike PyTorch, tinygrad isn't designed to be fast like that. While 75 ms for one step is plenty fast for debugging, it's not great for training. Here, we introduce the first quintessentially tinygrad concept, the `TinyJit`. diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md index 5fbf376ff8c8..2c04a1158999 100644 --- a/docs/tensor/ops.md +++ b/docs/tensor/ops.md @@ -13,6 +13,7 @@ ::: tinygrad.Tensor.softmax ::: tinygrad.Tensor.log_softmax ::: tinygrad.Tensor.logsumexp +::: tinygrad.Tensor.logcumsumexp ::: tinygrad.Tensor.argmax ::: tinygrad.Tensor.argmin diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 3a5f17af20e9..a011f3c1d26b 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -28,7 +28,7 @@ def train_resnet(): if getenv("LOGMLPERF"): from mlperf_logging import mllog import mlperf_logging.mllog.constants as mllog_constants - mllog.config(filename=f"result_{seed}.txt") + mllog.config(filename=f"result_resnet_{seed}.txt") mllog.config(root_dir=Path(__file__).parents[3].as_posix()) # truncate to log this. "file": "tinygrad/examples/mlperf/model_train.py" MLLOGGER = mllog.get_mllogger() if INITMLPERF: @@ -621,7 +621,7 @@ def train_bert(): from mlperf_logging import mllog import mlperf_logging.mllog.constants as mllog_constants - mllog.config(filename="bert.log") + mllog.config(filename=f"result_bert_{seed}.log") mllog.config(root_dir=Path(__file__).parents[3].as_posix()) MLLOGGER = mllog.get_mllogger() MLLOGGER.logger.propagate = False @@ -752,7 +752,12 @@ def train_bert(): else: i, train_data = start_step, get_fake_data_bert(GPUS, BS) + epoch_started = False while train_data is not None and i < train_steps and not achieved: + if not epoch_started and MLLOGGER and RUNMLPERF: + MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i+1, metadata=dict(epoch_num=i+1)) + epoch_started = True + Tensor.training = True BEAM.value = TRAIN_BEAM st = time.perf_counter() @@ -801,6 +806,8 @@ def train_bert(): # ** eval loop ** if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK): if MLLOGGER and RUNMLPERF: + epoch_started = False + MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=i+1, metadata=dict(epoch_num=i+1)) MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i}) if getenv("RESET_STEP", 1): train_step_bert.reset() eval_lm_losses = [] diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh index 98bacec1a1ea..9c2d247f1684 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 -export BEAM=4 +export BEAM=3 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh index 41bbc19740df..02d2902f1bf6 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 -export BEAM=4 +export BEAM=3 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh index 87026b4a9277..b25865214f3a 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh @@ -5,7 +5,7 @@ export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_red" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 -export BEAM=4 +export BEAM=3 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/extra/models/bert.py b/extra/models/bert.py index 4503ed0aed2e..8c91e27a8d0f 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -51,7 +51,7 @@ def __call__(self, input_ids:Tensor, attention_mask:Tensor, masked_lm_positions: # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315 def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1): - log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index) + log_probs, loss_mask = predictions.log_softmax(dtype=dtypes.float), (labels != ignore_index) y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1]) y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1]) return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero @@ -62,13 +62,12 @@ def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_ return masked_lm_loss + next_sentence_loss def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): - valid = masked_lm_ids != 0 - masked_lm_predictions = prediction_logits.log_softmax().argmax(-1) + masked_lm_predictions = prediction_logits.log_softmax(dtype=dtypes.float).argmax(-1) masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights) - seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1) + seq_relationship_predictions = seq_relationship_logits.log_softmax(dtype=dtypes.float).argmax(-1) seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels) next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels) diff --git a/extra/models/clip.py b/extra/models/clip.py index 96ec072aa107..7b176fcdf3c4 100644 --- a/extra/models/clip.py +++ b/extra/models/clip.py @@ -440,7 +440,7 @@ def prepare_image(self, image:Image.Image) -> Tensor: top = (h - SIZE) // 2 image = image.crop((0, SIZE, top, top+SIZE)) - x = Tensor(np.array(image.convert('RGB'))) + x = Tensor(np.array(image.convert('RGB')), device=self.std.device) x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0 return (x - self.mean) / self.std diff --git a/test/test_dtype.py b/test/test_dtype.py index fdc0c2eee259..49506a621a4c 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -760,6 +760,25 @@ def test_mean_half_precision_overflow(self): t.square().mean().backward() np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N) + @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") + def test_softmax_dtype(self): + data = [1, 2, 3] + t = Tensor(data, dtype=dtypes.half) + tt = torch.tensor(data, dtype=torch.half) + + out = t.softmax(0) + self.assertEqual(out.dtype, dtypes.half) + np.testing.assert_allclose(out.numpy(), tt.softmax(0).numpy(), rtol=1e-3) + out = t.softmax(0, dtype=dtypes.float) + self.assertEqual(out.dtype, dtypes.float) + np.testing.assert_allclose(out.numpy(), tt.softmax(0, dtype=torch.float).numpy(), rtol=1e-3) + out = t.log_softmax(0) + self.assertEqual(out.dtype, dtypes.half) + np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3) + out = t.log_softmax(0, dtype=dtypes.float) + self.assertEqual(out.dtype, dtypes.float) + np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3) + class TestImplicitFunctionTypeChange(unittest.TestCase): def test_functions(self): result = [] diff --git a/test/test_gc.py b/test/test_gc.py index 37e632acda92..8b7979e9aa0c 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -21,24 +21,26 @@ def test_gc(self): (a*b).mean().backward() assert (tensors_allocated() > 0) del a,b - assert (tensors_allocated() == 1) # one for Tensor._device_rng_counters + assert (tensors_allocated() == 2) # one for Tensor._device_rng_counters, and one for Tensor._device_seeds + Tensor.manual_seed(0) def test_gc_complex(self): Tensor.manual_seed(0) a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) b = Tensor.rand(4, 4, requires_grad=True) - assert (tensors_allocated() == 4) - (a*b).mean().backward() assert (tensors_allocated() == 5) + (a*b).mean().backward() + assert (tensors_allocated() == 6) del b - assert (tensors_allocated() == 3) + assert (tensors_allocated() == 4) b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) print(tensors_allocated()) (a*b).mean().backward() print(tensors_allocated()) - assert (tensors_allocated() == 5) + assert (tensors_allocated() == 6) del b - assert (tensors_allocated() == 3) + assert (tensors_allocated() == 4) + Tensor.manual_seed(0) def test_schedule_gc(self): init = bufs_allocated() diff --git a/test/test_jit.py b/test/test_jit.py index 6cb13409880b..6988a5ee759a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -307,6 +307,14 @@ def f(a, b): assert len(res3) == 10, "All values should be different, rand works in jit." assert res3 != res2, "Jit rand is diff with diff seeds" + def test_jit_random_after_unrealized_random(self): + @TinyJit + def f(): return Tensor.rand() + Tensor.manual_seed(1234) + Tensor.rand() + res = [f().numpy() for _ in range(3)] + assert res[1] != res[2] + def test_jit_realization_and_sampling(self): w = Tensor.eye(5) diff --git a/test/test_ops.py b/test/test_ops.py index 342b27a415d7..c4494aa31c99 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1072,6 +1072,14 @@ def test_logsumexp(self): helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7) + def test_logcumsumexp(self): + helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) + helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7) + helper_test_op([(45)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) + helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) + helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7) + helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, grad_atol=1e-7) + def test_sinh(self): helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6) # TODO: backward nan instead of inf diff --git a/test/test_randomness.py b/test/test_randomness.py index 43c6ea59eeff..daa5effeacbd 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -76,7 +76,7 @@ def test_rand_float16(self): equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) @unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry") - def test_threefly_against_reference(self): + def test_threefry_against_reference(self): Tensor.manual_seed(1337) # reference generated using @@ -92,11 +92,11 @@ def test_threefly_against_reference(self): counts = Tensor.arange(20, dtype=dtypes.uint32) counts0, counts1 = counts.chunk(2) - r = Tensor._threefry_random_bits(1337, 0, counts0, counts1).numpy() + r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy() np.testing.assert_allclose(jr, r) - def test_threefly_against_reference_full(self): + def test_threefry_against_reference_full(self): Tensor.manual_seed(1337) # reference generated using @@ -118,7 +118,7 @@ def test_threefly_against_reference_full(self): np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5) @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI") - def test_threefly_tensors_cnt(self): + def test_threefry_tensors_cnt(self): Tensor.manual_seed(1337) Tensor.rand(20).realize() @@ -136,6 +136,31 @@ def test_threefly_tensors_cnt(self): assert len(Tensor._device_rng_counters) == 0 assert len(Tensor._device_seeds) == 0 + @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI") + def test_threefry_same_kernels(self): + Tensor.manual_seed(0) + + Tensor.rand(1).realize() + + s = Tensor.rand(20).schedule() + s2 = Tensor.rand(20).schedule() + + assert len(s) == len(s2), f"{len(s)} != {len(s2)}" + for x,y in zip(s, s2): + if not (x.ast == y.ast): + print(f"{x.ast} != {y.ast}") + + Tensor.rand(1, device=f"{Device.DEFAULT}:1").realize() + + s3 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule() + s4 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule() + + assert len(s3) == len(s4), f"{len(s3)} != {len(s4)}" + assert len(s2) == len(s4), f"{len(s)} != {len(s3)}" + for x,y in zip(s3, s4): + if not (x.ast == y.ast): + print(f"{x.ast} != {y.ast}") + @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support") def test_rand_bfloat16(self): N = 128 diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index d2294f6050df..173bda904bb2 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -49,22 +49,19 @@ def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", class TestSymbolic(unittest.TestCase): def helper_test_variable(self, v, n, m, s): rendered, nmin, nmax = render(v) - if isinstance(s, set): - self.assertIn(rendered, s) - else: - self.assertEqual(rendered, s) + self.assertEqual(rendered, s) self.assertEqual(nmin, n) self.assertEqual(nmax, m) def test_cmp_simple(self): self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)") - self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))", '((a<8)!=1)'}) + self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=1)") def test_ge(self): self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0") self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0") - self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))", '((a<8)!=1)'}) - self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))", '((a<4)!=1)'}) + self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=1)") + self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a<4)!=1)") self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1") self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1") @@ -83,7 +80,7 @@ def test_ge_divides(self): def test_ge_divides_and(self): expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512), create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)]) - self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)&(idx2<128))"}) + self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))") # # bool divided by int is not allowed # expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512), # create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)]) @@ -91,7 +88,7 @@ def test_ge_divides_and(self): def test_lt_factors(self): expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512) - self.helper_test_variable(expr, 0, 1, {"(((idx1*4)+FLOAT4_INDEX)<512)", "(((FLOAT4_INDEX//4)+idx1)<128)"}) + self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)") def test_div_reduction(self): self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1") @@ -133,19 +130,19 @@ def test_factorize_no_mul(self): self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)") def test_neg(self): - self.helper_test_variable(-Variable("a", 0, 8), -8, 0, {"(a*-1)", "(a*(-1))"}) + self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)") def test_add_1(self): - self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, {"(1+a)", "(a+1)"}) + self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)") def test_add_num_1(self): - self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, {"(1+a)", "(a+1)"}) + self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, "(a+1)") def test_sub_1(self): - self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, {"(-1+a)", "(a+(-1))", "(a+-1)"}) + self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(a+-1)") def test_sub_num_1(self): - self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, {"(-1+a)", "(a+(-1))", "(a+-1)"}) + self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(a+-1)") def test_add_self(self): a = Variable("a", 0, 8) @@ -200,10 +197,10 @@ def test_sum_div_mod_factor(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0") def test_sum_div_some_factor(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, {"(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"}) + self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))") def test_sum_div_trim_const(self): - self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, {"((1+a+b)//4)", "((a+b+1)//4)"}) + self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "((a+b+1)//4)") def test_sum_div_some_partial_factor(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") @@ -218,7 +215,7 @@ def test_mod_factor(self): def test_mod_to_sub(self): # This is mod reduction - self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, {"(-1+a)", "(a+(-1))", "(a+-1)"}) + self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)") def test_sum_div_const(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a") @@ -228,8 +225,7 @@ def test_sum_div_const_big(self): def test_sum_lt_fold(self): self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]), 16), 0, 1, "(a<4)") - self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]), 16), 0, 1, - {"(((a*4)+b)<16)", "(((b//4)+a)<4)"}) + self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]), 16), 0, 1, "(((a*4)+b)<16)") # TODO: fix with self.assertRaises(AssertionError): self.helper_test_variable(create_lt_node(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]), (4 * 67)), 0, 1, "(a<23)") @@ -243,7 +239,7 @@ def test_mul_mod_small(self): def test_mod_mod(self): self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)") self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0") - self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, {"(((a*5)%12)%5)", "(((5*a)%12)%5)"}) + self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, "(((a*5)%12)%5)") self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)") def test_mul_mul(self): @@ -252,19 +248,19 @@ def test_mul_mul(self): def test_mul_lt(self): self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)") self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)") - self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, {"((a*-1)<0)", "((a*(-1))<0)"}) - self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-1))<(-2))", '((a<3)!=1)'}) - self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))", '((a<4)!=1)'}) + self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)") + self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a<3)!=1)") + self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a<4)!=1)") def test_div_div(self): self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)") def test_distribute_mul(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, {"((a*3)+(b*3))", "((a+b)*3)"}) - self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, {"((a*-2)+10)", "((a*(-2))+10)"}) + self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") + self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, "((a*-2)+10)") def test_mod_mul_sum(self): - self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, {"(a+b)", "(b+a)"}) + self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(b+a)") def test_sum_0(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a") @@ -310,17 +306,17 @@ def test_and_remove(self): self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a") def test_mod_factor_negative(self): - self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, {"((27+a)%28)", "((a+27)%28)"}) - self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, {"((27+a)%28)", "((a+27)%28)"}) + self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)") + self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)") def test_sum_combine_num(self): - self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, {"(6+a)", "(a+6)"}) + self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(a+6)") def test_sum_num_hoisted_and_factors_cancel_out(self): self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") def test_div_cancel(self): - self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, {"(-1+b)", "(b+(-1))", "(b+-1)"}) + self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)") def test_mod_cancel(self): self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)") @@ -330,8 +326,8 @@ def test_mul_div(self): def test_add_div(self): # careful about the lower bounds and upper bounds - self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, {"(((a+2)//4)+(-1))", "(((a+2)//4)+-1)"}) - self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, {"(((a+3)//4)+(-1))", "(((a+3)//4)+-1)"}) + self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, "(((a+2)//4)+-1)") + self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, "(((a+3)//4)+-1)") self.helper_test_variable((Variable("a", 0, 5))//4, 0, 1, "(a//4)") self.helper_test_variable((Variable("a", 0, 5)+1)//4, 0, 1, "((a+1)//4)") self.helper_test_variable((Variable("a", 0, 5)+2)//4, 0, 1, "((a+2)//4)") @@ -357,15 +353,15 @@ def test_div_into_mod(self): # TODO: simplify the expression def test_div_neg_cancel(self): - self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, {"((((idx*(-1))+199)//(-4))+50)", "((((idx*-1)+199)//-4)+50)"}) - self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, {"((((idx*(-1))+200)//(-4))+50)", "((((idx*-1)+200)//-4)+50)"}) - self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, {"((((idx*(-1))+201)//(-4))+50)", "((((idx*-1)+201)//-4)+50)"}) + self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((((idx*-1)+199)//-4)+50)") + self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((idx*-1)+200)//-4)+50)") + self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((idx*-1)+201)//-4)+50)") def test_sum_div_big_const(self): gidx0 = Variable("gidx0", 0, 24) - self.helper_test_variable((gidx0+19)//20, 0, 2, {"((19+gidx0)//20)", "((gidx0+19)//20)"}) + self.helper_test_variable((gidx0+19)//20, 0, 2, "((gidx0+19)//20)") self.helper_test_variable((gidx0+20)//20, 1, 2, "((gidx0//20)+1)") - self.helper_test_variable((gidx0+21)//20, 1, 2, {"(((1+gidx0)//20)+1)", "(((gidx0+1)//20)+1)"}) + self.helper_test_variable((gidx0+21)//20, 1, 2, "(((gidx0+1)//20)+1)") def test_sum_div_complex1(self): gidx0 = Variable("gidx0", 0, 24) @@ -375,8 +371,7 @@ def test_sum_div_complex1(self): lidx1 = Variable("lidx1", 0, 15) lidx2 = Variable("lidx2", 0, 3) alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10 - self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, {"((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(gidx1*8)+(gidx2*32)+(lidx0*16))", - "((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(gidx2*32)+(gidx1*8)+(lidx0*16))"}) + self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, "((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(gidx2*32)+(gidx1*8)+(lidx0*16))") def test_sum_div_complex2(self): gidx0 = Variable("gidx0", 0, 7) @@ -427,9 +422,8 @@ def test_div_neg_then_neg(self): lidx1 = Variable("lidx1", 0, 7) alu2 = -lidx0-lidx1 self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4") - self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, {"(-4)", "-4"}) - self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, {"((((lidx0*(-1))+(lidx1*(-1))+134)//(-32))+4)", - "((((lidx0*-1)+(lidx1*-1)+134)//-32)+4)"}) + self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "-4") + self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "((((lidx0*-1)+(lidx1*-1)+134)//-32)+4)") self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0") self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0") self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0") @@ -453,12 +447,12 @@ def test_gated_load(self): idx = Variable("idx", 0, 24) self.helper_test_variable(idx//4, 0, 6, "(idx//4)") # TODO: simplify the true branch - self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, {"((idx<4)?(idx//4):(-1))", "((idx<4)?(idx//4):-1)"}) + self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx<4)?(idx//4):-1)") def test_idiv_lt(self): idx = Variable("idx", 0, 24) self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)") - self.helper_test_variable((idx//-4).lt(-3), 0, 1, {"((idx//(-4))<(-3))", "((idx//-4)<-3)"}) + self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//-4)<-3)") def test_simplex_lt(self): a = Variable("a", 0, 3) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index ce1de21ddb1d..1c468b849b1a 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -1,10 +1,10 @@ from __future__ import annotations from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, DefaultDict, Callable -import functools, itertools, heapq, math, operator +import functools, itertools, heapq, operator from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType, DType from tinygrad.ops import UnaryOps, BinaryOps, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element -from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps, simple_pm +from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps, symbolic_flat, is_irreducible, _get_chain from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES if TYPE_CHECKING: from tinygrad.renderer import Renderer @@ -75,108 +75,8 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp): (UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(UPat.var("buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded), ]) -# ***** mod ***** - -def _get_chain(x:UOp, sep:BinaryOps): - if x.op is UOps.ALU and x.arg is sep: - for s in x.src: yield from _get_chain(s, sep) - else: yield x - -def mod_folding(x:UOp, c:int) -> Optional[UOp]: - # simplify x % c, None means no change - - # simple cancel mod case - if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c - - remainder, something_changed = [], False - for u in _get_chain(x, BinaryOps.ADD): - if (factor:=u.const_factor())%c != factor: - divides = u.divides(factor)*(factor%c) - assert divides is not None - remainder.append(divides) - something_changed = True - elif u.op is UOps.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is UOps.CONST and s1.arg%c == 0: - remainder.append(u.src[0]) - something_changed = True - else: remainder.append(u) - if not something_changed: return None - return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0) - -def div_folding(x:UOp, c:int) -> Optional[UOp]: - # simplify x // c, None means no change - - # simple cancel div case - if 0 <= x.vmin and x.vmax < c: return x.const_like(0) - - quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1 - for u in _get_chain(x, BinaryOps.ADD): - if u.op is UOps.CONST: - # add all const together first - if rem_const != 0: something_changed = True - rem_const += u.arg - elif (factor:=u.const_factor())%c == 0: - if factor: - divides = u.divides(c) - assert divides is not None - quotient.append(divides) - something_changed = True - else: - # divisor is the smallest common divisor of all MULs - if u.op is UOps.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor - remainder.append(u) - gcd = math.gcd(gcd, factor) - - # handle the const - if rem_const%c != rem_const: - something_changed = True - quotient.append(x.const_like(rem_const//c)) - rem_const = rem_const%c - if rem_const != 0: remainder.append(x.const_like(rem_const)) - - # x // c -> quotient + (remainder // div) // (c // div) - div = gcd if gcd > 1 else divisor - - if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None - rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None - quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None - if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div) - return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo - -def lt_folding(x:UOp, c:int) -> Optional[UOp]: - return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None - -def fold_unrolled_divs(divs:UOp): - # div pattern in unrolled arange - # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x - add_chain, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), [], None - for u in add_chain: - if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==len(add_chain)): return None - # assumed CONST is the last of an ADD - if (s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST: - seen_const.append(s0.src[1].arg) - s0 = s0.src[0] - else: seen_const.append(0) - if ans is None: ans = s0 - if ans is not s0: return None - return ans if ans is not None and sorted(seen_const)==list(range(len(add_chain))) else None - # ***** image load valid simplification ***** -def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE) - -def canonicalize_simplex(X:UOp) -> Optional[UOp]: - # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. - # returns x0 + x1 + ... in such case, or None if not - changed, ret = False, [] - for u in _get_chain(X, BinaryOps.ADD): - # assumed the const is the last src of MUL - if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0: - changed = True - u = u.src[0] - if not (is_irreducible(u) and u.vmin >= 0): return None - ret.append(u) - return functools.reduce(operator.add, ret) if changed else None - def is_increasing(f:UOp): # is f a monotonically increasing function regards its input if f.op is UOps.CONST or is_irreducible(f): return True @@ -370,7 +270,7 @@ def no_vectorized_wmma(wmma:UOp): return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex)) # this is symbolic 2.0 -sym = simple_pm+PatternMatcher([ +sym = symbolic_flat+PatternMatcher([ # self ASSIGN is just self (UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x), # ASSIGN to global is just self @@ -419,8 +319,6 @@ def no_vectorized_wmma(wmma:UOp): .lt(UPat.cvar("compval")).ne(UPat(UOps.CONST, name="ne", arg=True)) .where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), - # unrolled arange div folding - (UPat(UOps.ALU, name="divs", src=[UPat(), UPat(UOps.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs), # indexing, with cast or where (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()* UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))), @@ -430,28 +328,12 @@ def no_vectorized_wmma(wmma:UOp): name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse), # GEP/CAST const rules (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)), - # ** combine terms (opinionated) ** - (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y - # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue - ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), # ** self folding ** # cast NOOP (NOTE: it's str to deal with PtrDType) (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None), (UPat(UOps.REDUCE, src=(UPat.var("x"),)), lambda x: x), # a REDUCE without ranges is a NOOP # ** load/store folding ** (UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.load(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)), - # *** rules from symbolic *** - # generic lt folding - (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), - # canonicalize a simplex with positive coefficients > 0 - # not x < 1 -> X > 0 - (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), - # ** div ** - # # div folding - (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None), - # ** mod ** - # mod folding - (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), # x!=0 -> (bool)x (UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))), # TODO: can do the invert of this (flip alt/load) when we fix double ops diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index d274981627a1..b7c954c9e4a8 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Union, Optional, Any, Tuple, List, get_args -from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype +from tinygrad.dtype import dtypes, DType, ConstType, to_dtype from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU from tinygrad.ops import identity_element, MathTrait, resolve, UOp @@ -10,7 +10,7 @@ from weakref import ref, ReferenceType, WeakValueDictionary lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary() -def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), +def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))): if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None dtype = to_dtype(dtype) @@ -25,7 +25,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[ view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "DSP", "DISK"} class LazyBuffer(MathTrait): - def __init__(self, device:str, st:ShapeTracker, dtype:DTypeLike, + def __init__(self, device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None): self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata @@ -68,7 +68,7 @@ def base(self) -> LazyBuffer: return self._base if self._base is not None else s def lbs(self) -> List[LazyBuffer]: return [self] @staticmethod - def metaop(op, shape:Tuple[sint,...], dtype:DTypeLike, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer: + def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer: assert isinstance(src, tuple) return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 4a19c53327c0..69f047efec81 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -41,7 +41,8 @@ def fully_flatten(l): return [l] def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm) def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst -def round_up(num, amt:int): return (num+amt-1)//amt * amt +def ceildiv(num:int, amt:int) -> int: return -int(num//-amt) +def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF) def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32) def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d6f1529b22b3..0ad12675d10d 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.shapetracker import ShapeTracker - from tinygrad.codegen.kernel import Kernel # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses class FastEnum(IntEnum): @@ -209,7 +208,7 @@ def key(self) -> bytes: def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg # *** uop evaluation *** - def simplify(self): return graph_rewrite(self, simple_pm) + def simplify(self): return graph_rewrite(self, symbolic) def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is UOps.CONST else ret def _eval(self, dtype, expected_type) -> ConstType: assert self.dtype in dtype, f"eval with wrong dtype {self}" @@ -587,9 +586,9 @@ def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]: class TrackedRewriteContext: loc: Tuple[str, int] # location that called graph_rewrite sink: UOp # the sink passed into the rewrite - kernel: Optional[Kernel] = None # the kernel being rewritten rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat) -contexts: List[TrackedRewriteContext] = [] +rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = [] +contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = [] class TrackedPatternMatcher(PatternMatcher): def __init__(self, patterns:List[Tuple[UPat, Callable]]): super().__init__(patterns) @@ -610,7 +609,7 @@ def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]: match_stats[p][2] += (et:=time.perf_counter()-st) match_stats[p][3] += et if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable()) - if TRACK_MATCH_STATS >= 2 and contexts and isinstance(ret, UOp): contexts[-1].rewrites.append((uop, ret, p)) + if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].rewrites.append((uop, ret, p)) return ret # NOTE: if it returns None, we keep trying to match match_stats[p][2] += time.perf_counter()-st return None @@ -622,7 +621,7 @@ def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]: def print_match_stats(): if TRACK_MATCH_STATS >= 2: with open("/tmp/rewrites.pkl", "wb") as f: - print(f"rewrote {len(contexts)} graphs and applied {sum(len(x.rewrites) for x in contexts)} rules, saved to /tmp/rewrites.pkl") + print(f"rewrote {len(contexts)} graphs and applied {sum(len(r.rewrites) for _,x in contexts for r in x)} rules, saved to /tmp/rewrites.pkl") pickle.dump(contexts, f) if getenv("VIZ"): os.environ["VIZ"] = "0" @@ -656,8 +655,10 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: # get Kernel we are rewriting in the context of frm_walk: Optional[FrameType] = frm while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back - contexts.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel)) - return RewriteContext(pm, ctx).rewrite(sink) + rewrite_stack.append((kernel, [TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)])) + ret = RewriteContext(pm, ctx).rewrite(sink) + if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop()) + return ret # ***** uop type spec ***** @@ -745,7 +746,105 @@ def type_verify(uops:List[UOp]): # *** most of symbolic lives here now *** -simple_pm = PatternMatcher([ +def _get_chain(x:UOp, sep:BinaryOps): + if x.op is UOps.ALU and x.arg is sep: + for s in x.src: yield from _get_chain(s, sep) + else: yield x + +def mod_folding(x:UOp, c:int) -> Optional[UOp]: + # simplify x % c, None means no change + + # simple cancel mod case + if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c + + remainder, something_changed = [], False + for u in _get_chain(x, BinaryOps.ADD): + if (factor:=u.const_factor())%c != factor: + divides = u.divides(factor)*(factor%c) + assert divides is not None + remainder.append(divides) + something_changed = True + elif u.op is UOps.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is UOps.CONST and s1.arg%c == 0: + remainder.append(u.src[0]) + something_changed = True + else: remainder.append(u) + if not something_changed: return None + return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0) + +def div_folding(x:UOp, c:int) -> Optional[UOp]: + # simplify x // c, None means no change + + # simple cancel div case + if 0 <= x.vmin and x.vmax < c: return x.const_like(0) + + quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1 + for u in _get_chain(x, BinaryOps.ADD): + if u.op is UOps.CONST: + # add all const together first + if rem_const != 0: something_changed = True + rem_const += u.arg + elif (factor:=u.const_factor())%c == 0: + if factor: + divides = u.divides(c) + assert divides is not None + quotient.append(divides) + something_changed = True + else: + # divisor is the smallest common divisor of all MULs + if u.op is UOps.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor + remainder.append(u) + gcd = math.gcd(gcd, factor) + + # handle the const + if rem_const%c != rem_const: + something_changed = True + quotient.append(x.const_like(rem_const//c)) + rem_const = rem_const%c + if rem_const != 0: remainder.append(x.const_like(rem_const)) + + # x // c -> quotient + (remainder // div) // (c // div) + div = gcd if gcd > 1 else divisor + + if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None + rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None + quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None + if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div) + return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo + +def lt_folding(x:UOp, c:int) -> Optional[UOp]: + return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None + +def fold_unrolled_divs(divs:UOp): + # div pattern in unrolled arange + # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x + add_chain, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), [], None + for u in add_chain: + if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==len(add_chain)): return None + # assumed CONST is the last of an ADD + if (s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST: + seen_const.append(s0.src[1].arg) + s0 = s0.src[0] + else: seen_const.append(0) + if ans is None: ans = s0 + if ans is not s0: return None + return ans if ans is not None and sorted(seen_const)==list(range(len(add_chain))) else None + +def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE) + +def canonicalize_simplex(X:UOp) -> Optional[UOp]: + # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. + # returns x0 + x1 + ... in such case, or None if not + changed, ret = False, [] + for u in _get_chain(X, BinaryOps.ADD): + # assumed the const is the last src of MUL + if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0: + changed = True + u = u.src[0] + if not (is_irreducible(u) and u.vmin >= 0): return None + ret.append(u) + return functools.reduce(operator.add, ret) if changed else None + +symbolic = PatternMatcher([ # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y), (UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y), @@ -816,6 +915,27 @@ def type_verify(uops:List[UOp]): # ** move mul consts to end (NOTE: this is still happening before constant folding) ** (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x*c1 if x.op not in (UOps.CONST, UOps.VCONST) else None), (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), + # *** rules from symbolic *** + # unrolled arange div folding + (UPat(UOps.ALU, name="divs", src=[UPat(), UPat(UOps.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs), + # generic lt folding + (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), + # canonicalize a simplex with positive coefficients > 0 + # not x < 1 -> X > 0 + (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), + # ** div ** + # # div folding + (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None), + # ** mod ** + # mod folding + (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), +]) + +symbolic_flat = symbolic+PatternMatcher([ + # ** combine terms (opinionated) ** + (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y + # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue + ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), ]) # for debug diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 806dbecce88e..65c949b2ba6c 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -6,8 +6,7 @@ from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, resolve -from tinygrad.codegen.uopgraph import sym, _get_chain +from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, resolve, _get_chain, symbolic_flat def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]: @@ -90,15 +89,15 @@ def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]: if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides ret: List[Optional[sint]] = [None] * len(self.shape) idx, valid = self.to_indexed_uops() - idx = graph_rewrite(idx, pm=sym) + idx = graph_rewrite(idx, symbolic_flat) for c in _get_chain(idx, BinaryOps.ADD): if c.op is UOps.RANGE: ret[c.arg] = 1 if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg - used_ranges = [x.arg for x in graph_rewrite(idx, pm=sym).sparents if x.op is UOps.RANGE] + used_ranges = [x.arg for x in graph_rewrite(idx, symbolic_flat).sparents if x.op is UOps.RANGE] ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)] if not ignore_valid: - masked_axis = [x.arg for x in graph_rewrite(valid, pm=sym).sparents if x.op is UOps.RANGE] + masked_axis = [x.arg for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is UOps.RANGE] ret = [None if i in masked_axis else x for i,x in enumerate(ret)] return tuple(ret) @@ -106,7 +105,7 @@ def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st def axis_is_masked(self, axis:int) -> bool: _, valid = self.to_indexed_uops() - return axis in [x.arg for x in graph_rewrite(valid, sym).sparents if x.op is UOps.RANGE] + return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is UOps.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 86432564c7d2..50da24a6859b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -7,7 +7,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup -from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA +from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv from tinygrad.multi import MultiLazyBuffer from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps from tinygrad.device import Device, Buffer, BufferOptions @@ -357,17 +357,17 @@ def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optiona """ assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer" - canonical_devices, bounds = tuple(Device.canonicalize(x) for x in devices), None + devices, bounds = tuple(Device.canonicalize(x) for x in devices), None if axis is not None: if axis < 0: axis += len(self.shape) if splits is None: - sz = round_up(self.shape[axis], len(devices)) // len(devices) - splits = tuple([max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]) + if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}") + sz = round_up(total, len(devices)) // len(devices) + splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))]) assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape" boundaries = tuple(itertools.accumulate(splits)) bounds = tuple(zip((0,) + boundaries, boundaries)) - return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis, bounds), - device=canonical_devices, requires_grad=self.requires_grad) + return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad) def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None): """ @@ -389,10 +389,11 @@ def from_uop(y:UOp, **kwargs) -> Tensor: @staticmethod def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs): + dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float if isinstance(device, tuple): - return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \ - for d in device], None), device, dtype, **kwargs) - return Tensor(LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs) + return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None), + device, dtype, **kwargs) + return Tensor(LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs) @staticmethod def empty(*shape, **kwargs): @@ -425,7 +426,7 @@ def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor: return r _seed: int = int(time.time()) - _device_seeds: Dict[str, int] = {} + _device_seeds: Dict[str, Tensor] = {} _device_rng_counters: Dict[str, Tensor] = {} @staticmethod def manual_seed(seed=0): @@ -446,9 +447,8 @@ def manual_seed(seed=0): Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {} @staticmethod - def _threefry_random_bits(key0, key1, counts0, counts1): + def _threefry_random_bits(key, counts0, counts1): x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64) - key = (Tensor([key0], device=x.device, dtype=dtypes.uint64, requires_grad=False) << 32) | key1 x = F.Threefry.apply(*x._broadcasted(key)) counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32) return counts0.cat(counts1) @@ -477,21 +477,23 @@ def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, cont # generate per device seeds and rng counter if we haven't seen this device yet if device not in Tensor._device_seeds: - Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff + Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \ + | int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff], + device=device, dtype=dtypes.uint64, requires_grad=False) Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False) had_counter = False else: had_counter = True # if shape has 0, return zero tensor - if (num := math.ceil(((num_ := prod(shape)) * dtype.itemsize) / 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs) + if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs) # increment rng counter for devices - if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num) + if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous() # threefry random bits - counts0 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device]) - counts1 = counts0 + math.ceil(num / 2) - bits = Tensor._threefry_random_bits(Tensor._seed, Tensor._device_seeds[device], counts0, counts1)[:num] + counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device]) + counts1 = counts0 + ceildiv(num, 2) + bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num] # bitcast to uint with same number of bits _, nmant = dtypes.finfo(dtype) @@ -591,8 +593,8 @@ def arange(start, stop=None, step=1, **kwargs) -> Tensor: assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}" dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs - if (stop-start)/step <= 0: return Tensor([], dtype=dtype, **kwargs) - return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype) + if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs) + return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype) @staticmethod def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor: @@ -1076,6 +1078,7 @@ def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor: if any(abs(st) != 1 for st in strides): strides = tuple(abs(s) for s in strides) # pad shape to multiple of stride + if not all_int(ret.shape): raise RuntimeError("symbolic shape not supprted") ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides))) ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides)))) ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2]) @@ -1298,7 +1301,7 @@ def chunk(self, chunks:int, dim:int=0) -> List[Tensor]: assert all_int(self.shape), f"does not support symbolic shape {self.shape}" assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}" dim = self._resolve_dim(dim) - return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim)) + return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim)) def squeeze(self, dim:Optional[int]=None) -> Tensor: """ @@ -1674,12 +1677,13 @@ def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, """ return self.std(axis, keepdim, correction), self.mean(axis, keepdim) - def _softmax(self, axis): - m = self - self.max(axis=axis, keepdim=True) + def _softmax(self, axis, dtype:Optional[DTypeLike]=None): + x = self.cast(dtype) if dtype is not None else self + m = x - x.max(axis=axis, keepdim=True) e = m.exp() return m, e, e.sum(axis=axis, keepdim=True) - def softmax(self, axis=-1): + def softmax(self, axis=-1, dtype:Optional[DTypeLike]=None): """ Applies the softmax function to the tensor along the specified axis. @@ -1699,10 +1703,10 @@ def softmax(self, axis=-1): print(t.softmax(axis=0).numpy()) ``` """ - _, e, ss = self._softmax(axis) + _, e, ss = self._softmax(axis, dtype) return e.div(ss) - def log_softmax(self, axis=-1): + def log_softmax(self, axis=-1, dtype:Optional[DTypeLike]=None): """ Applies the log-softmax function to the tensor along the specified axis. @@ -1722,7 +1726,7 @@ def log_softmax(self, axis=-1): print(t.log_softmax(axis=0).numpy()) ``` """ - m, _, ss = self._softmax(axis) + m, _, ss = self._softmax(axis, dtype) return m - ss.log() def logsumexp(self, axis=None, keepdim=False): @@ -1752,6 +1756,33 @@ def logsumexp(self, axis=None, keepdim=False): m = self.max(axis=axis, keepdim=True) return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis) + def logcumsumexp(self, axis=0): + """ + Computes the log-cumsum-exp of the tensor along the specified axis or axes. + + The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials. + + You can pass in the `axis` keyword argument to control the axis along which + the log-cum-sum-exp is computed. + + ```python exec="true" source="above" session="tensor" result="python" + Tensor.manual_seed(42) + t = Tensor.randn(2, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.logcumsumexp().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.logcumsumexp(axis=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.logcumsumexp(axis=1).numpy()) + ``` + """ + m = self.max(axis=axis, keepdim=True) + return (self - m).exp().cumsum(axis=axis).log() + m + def argmax(self, axis=None, keepdim=False): """ Returns the indices of the maximum value of the tensor along the specified axis. @@ -1896,10 +1927,10 @@ def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilat s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_)) assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}" noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):] - o_ = [math.ceil((i - d * (k-1))/s) for i,d,k,s in zip(i_, d_, k_, s_)] + o_ = [ceildiv(i - d * (k-1), s) for i,d,k,s in zip(i_, d_, k_, s_)] if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_): # repeats such that we don't need padding - xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)]) + xup = self.repeat([1]*len(noop_) + [ceildiv(k*(i+d), i) for k,i,d in zip(k_, i_, d_)]) # handle dilation xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_))) # handle stride @@ -2101,12 +2132,11 @@ def cumsum(self, axis:int=0) -> Tensor: # TODO: someday the optimizer will find this on it's own # for now this is a two stage cumsum SPLIT = 256 - if self.shape[axis] <= SPLIT*2: return self._cumsum(axis) - ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0)) - ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1) + if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis) + ret = self.transpose(axis,-1).pad2d((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1) base_add = ret[..., -1]._cumsum(-1, _first_zero=True) base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1]) - def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1) + def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1) return fix(ret) + fix(base_add) @staticmethod diff --git a/viz/index.html b/viz/index.html index ca2e45d37574..88a301730088 100644 --- a/viz/index.html +++ b/viz/index.html @@ -269,7 +269,7 @@ ret = await (await fetch(`/kernels?kernel=${currentKernel}&idx=${currentUOp}`)).json(); cache[cacheKey] = ret; } - renderGraph(ret.graphs[currentRewrite], ret.changed_nodes[currentRewrite]); + renderGraph(ret.graphs[currentRewrite], currentRewrite == 0 ? [] : ret.changed_nodes[currentRewrite-1]); // ***** RHS metadata const metadata = document.querySelector(".container.metadata"); metadata.innerHTML = ""; diff --git a/viz/serve.py b/viz/serve.py index 616d5e34339a..3c56c105bd85 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -1,26 +1,28 @@ #!/usr/bin/env python3 from collections import defaultdict -from typing import DefaultDict, Dict, List, Optional, Tuple +from typing import Any, DefaultDict, Dict, List, Optional, Tuple import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, multiprocessing, functools from dataclasses import asdict from urllib.parse import parse_qs, urlparse from http.server import HTTPServer, BaseHTTPRequestHandler -from tinygrad.helpers import getenv, to_function_name -from tinygrad.ops import TrackedRewriteContext, UOp, UOps, UPat, lines +from tinygrad.codegen.kernel import Kernel +from tinygrad.helpers import getenv, to_function_name, tqdm +from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines from tinygrad.engine.graph import uops_colors, word_wrap from viz.spec import GraphRewriteDetails, GraphRewriteMetadata -def reconstruct_graph(sink:UOp, rewrites:List[Tuple[UOp, UOp, UPat]]) -> Tuple[List[UOp], List[List[str]], List[List[int]]]: - uops: List[UOp] = [sink] +def reconstruct_graph(ctx:TrackedRewriteContext) -> Tuple[List[UOp], List[List[str]], List[List[int]]]: + uops: List[UOp] = [ctx.sink] diffs: List[List[str]] = [] - changed_nodes: List[List[int]] = [[]] - seen_replaces: Dict[bytes, UOp] = {} - for i, (first, rewritten, _) in enumerate(rewrites): + changed_nodes: List[List[int]] = [] + seen_replaces: Dict[UOp, UOp] = {} + for i, (first, rewritten, upat) in enumerate(ctx.rewrites): # first, rewrite this UOp with the current rewrite + all the seen rewrites before this - seen_replaces[first.key] = rewritten + seen_replaces[first] = rewritten new_sink = replace_uop(uops[-1], {**seen_replaces}) # sanity check - assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}" + if new_sink is uops[-1]: + raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}") # update ret data changed_nodes.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST]) diffs.append(list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))) @@ -41,22 +43,25 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff")) return graph -def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp: - if (found:=replaces.get(base.key)) is not None: return found - new_srcs = tuple(replace_uop(x, replaces) for x in base.src) - replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base +def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp: + if (found:=replaces.get(base)) is not None: return found + replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src)) return ret -def load_kernels(contexts) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]: +def load_kernels(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, \ + TrackedRewriteContext, Any]]]: kernels = defaultdict(list) - for ctx in contexts: - name = to_function_name(ctx.kernel.name) if ctx.kernel is not None else None - upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites] - kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx)) + for k,rewrites in contexts: + if isinstance(k, Kernel): name = to_function_name(k.name) + else: name = None + for ctx in rewrites: + if ctx.sink.op is UOps.CONST: continue + upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites] + kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx, k)) return kernels @functools.lru_cache(None) -def get_src(k) -> Optional[str]: return k.to_program().src if k else None +def get_src(k) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None class Handler(BaseHTTPRequestHandler): def do_GET(self): @@ -78,10 +83,10 @@ def do_GET(self): self.end_headers() query = parse_qs(url.query) if (qkernel:=query.get("kernel")) is not None: - metadata, ctx = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])] - graphs, diffs, changed_nodes = reconstruct_graph(ctx.sink, ctx.rewrites) + metadata, ctx, k = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])] + graphs, diffs, changed_nodes = reconstruct_graph(ctx) ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)), - diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(ctx.kernel)))).encode() + diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(k)))).encode() else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode() else: self.send_response(404) @@ -101,9 +106,12 @@ def reloader(): if __name__ == "__main__": multiprocessing.current_process().name = "VizProcess" # disallow opening of devices print("*** viz is starting") - with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f) + with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = pickle.load(f) print("*** unpickled saved rewrites") kernels = load_kernels(contexts) + if getenv("FUZZ_VIZ"): + for v in tqdm(kernels.values()): + for _,ctx,_ in v: reconstruct_graph(ctx) print("*** loaded kernels") server = HTTPServer(('', 8000), Handler) st = time.perf_counter() diff --git a/viz/test_viz.py b/viz/test_viz.py index bddbaf2196d2..089b83dc11fe 100644 --- a/viz/test_viz.py +++ b/viz/test_viz.py @@ -1,19 +1,17 @@ -from typing import List +from typing import Any, List, Tuple import unittest import os, itertools - -from viz.spec import GraphRewriteMetadata os.environ["TRACK_MATCH_STATS"] = "2" os.environ["PRINT_MATCH_STATS"] = "0" -from extra.models.resnet import ResNet50 from tinygrad import Tensor from tinygrad.engine.realize import lower_schedule from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps from tinygrad.dtype import dtypes, PtrDType -from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv +from tinygrad.helpers import Context, all_same, DEBUG, getenv from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding from test.external.process_replay.helpers import print_diff from viz.serve import reconstruct_graph, uop_to_json, load_kernels +from viz.spec import GraphRewriteMetadata def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)} @@ -22,16 +20,9 @@ def tearDown(self) -> None: from tinygrad.ops import contexts if not getenv("VIZ"): contexts.clear() - def assert_valid_ctx(self, contexts:List[TrackedRewriteContext]): + def assert_valid_ctx(self, contexts:List[Tuple[Any,List[TrackedRewriteContext]]]): assert len(contexts) != 0 - for i,ctx in enumerate(contexts): - try: graphs,_,_ = reconstruct_graph(ctx.sink, ctx.rewrites) - except Exception as e: - print(colored(f"failed to create graph for ctx {i}", "red")) - raise e - for j,(x,y) in enumerate(zip(graphs, graphs[1:])): - if x.key == y.key: - raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}") + load_kernels(contexts) def assert_valid_graph(self, t): contexts.clear() @@ -52,8 +43,8 @@ def test_ctx_groups(self): list(lower_schedule(schedule2)) with Context(TRACK_MATCH_STATS=0): ret = list(load_kernels(contexts).values()) assert len(ret) == 3 - assert all(len([x for x,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:]) - assert all(len([x for x,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:]) + assert all(len([x for x,_,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:]) + assert all(len([x for x,_,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:]) def test_gemm_diff(self): x = Tensor.empty(64, 64).realize() @@ -72,7 +63,7 @@ def test_removed_node(self): ]) ret = graph_rewrite(sink, pm) if DEBUG >= 4: print_diff(sink, ret) - graphs,_,_ = reconstruct_graph(contexts[0].sink, contexts[0].rewrites) + graphs,_,_ = reconstruct_graph(contexts[0][1][0]) assert graphs[-1].key == ret.key self.assert_valid_ctx(contexts) @@ -104,16 +95,7 @@ def test_devectorize_viz(self): new_sink = graph_rewrite(sink, pm) if DEBUG >= 4: print_diff(sink, new_sink, unified=0) self.assert_valid_ctx(contexts) - assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts) - - @unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules") - def test_fuzz_resnet(self): - mdl = ResNet50() - img = Tensor.empty(64, 3, 224, 224) - out = mdl(img) - sched = out.schedule() - list(lower_schedule(sched)) - self.assert_valid_ctx(contexts) + assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs) def test_no_ctx(self): simple_pm = PatternMatcher([(UPat(UOps.CONST), lambda:True)])