Skip to content

Commit

Permalink
enable test_resnet_half (tinygrad#7141)
Browse files Browse the repository at this point in the history
already worked so just fixed the test
  • Loading branch information
chenyuxyz authored Oct 17, 2024
1 parent 211d975 commit 72ed662
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions test/unit/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,6 @@ def test_substitute(self):
"""

class TestSymbolicRealWorld(unittest.TestCase):
@unittest.expectedFailure
def test_resnet_half(self):
gidx0 = Variable("gidx0", 0, 3)
gidx1 = Variable("gidx1", 0, 127)
Expand All @@ -676,10 +675,12 @@ def test_resnet_half(self):
lidx4 = Variable("lidx4", 0, 1)
lidx5 = Variable("lidx5", 0, 15)

idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
print(idx.render())
idx:UOp = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
idx = graph_rewrite(idx, sym)
# print(idx.render())
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
assert idx.render() == \
"((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)"

class TestBounds(unittest.TestCase):
def test_unrolled_arange(self):
Expand Down

0 comments on commit 72ed662

Please sign in to comment.