From 72ed66205d1c8e56b95e540ac8ca121ccd324e4e Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 17 Oct 2024 19:02:20 -0400 Subject: [PATCH] enable test_resnet_half (#7141) already worked so just fixed the test --- test/unit/test_uop_symbolic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 1a883acb004c..9f0ee23a7b45 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -667,7 +667,6 @@ def test_substitute(self): """ class TestSymbolicRealWorld(unittest.TestCase): - @unittest.expectedFailure def test_resnet_half(self): gidx0 = Variable("gidx0", 0, 3) gidx1 = Variable("gidx1", 0, 127) @@ -676,10 +675,12 @@ def test_resnet_half(self): lidx4 = Variable("lidx4", 0, 1) lidx5 = Variable("lidx5", 0, 15) - idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3) - print(idx.render()) + idx:UOp = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3) + idx = graph_rewrite(idx, sym) + # print(idx.render()) # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range. - assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)" + assert idx.render() == \ + "((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)" class TestBounds(unittest.TestCase): def test_unrolled_arange(self):