Skip to content

Commit

Permalink
fix some sigmoid extreme
Browse files Browse the repository at this point in the history
Quite brittle... the problem is that the rewrite pattern has three multiplied terms, and the multiplication may be associated in an order the matcher does not expect.
  • Loading branch information
chenyuxyz committed Dec 14, 2024
1 parent bcd7ea6 commit 2b2558a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 0 additions & 4 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,14 +807,12 @@ def test_softsign_exact(self):
def test_sigmoid(self):
    """Compare Tensor.sigmoid against torch.sigmoid on a 2-D tensor and a scalar."""
    for shape in [(45, 65), ()]:
        helper_test_op([shape], torch.sigmoid, Tensor.sigmoid)
@unittest.skip("TODO: fix sigmoid stability")
def test_sigmoid_extreme(self):
    """Sigmoid should saturate cleanly (no NaN/inf) for very large-magnitude inputs."""
    for lo, hi in [(300, 400), (-400, -300)]:
        helper_test_op([(45, 65)], torch.sigmoid, Tensor.sigmoid, low=lo, high=hi)
def test_hardsigmoid(self):
    """Compare Tensor.hardsigmoid against torch's hardsigmoid on a 2-D tensor and a scalar."""
    for shape in [(45, 65), ()]:
        helper_test_op([shape], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid)
@unittest.skip("TODO: fix sigmoid stability")
def test_hardsigmoid_extreme(self):
    """Hardsigmoid at extreme input magnitudes.

    Fix: the body was copy-pasted from test_sigmoid_extreme and called
    torch.sigmoid/Tensor.sigmoid, so this test never exercised hardsigmoid
    at all. It now tests the function it is named for.
    """
    helper_test_op([(45,65)], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid, low=300, high=400)
    helper_test_op([(45,65)], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid, low=-400, high=-300)
Expand All @@ -835,14 +833,12 @@ def test_erf(self):

def test_gelu(self):
    """Compare Tensor.gelu against torch's tanh-approximated gelu."""
    torch_gelu = lambda x: torch.nn.functional.gelu(x, approximate="tanh")
    helper_test_op([(45, 65)], torch_gelu, Tensor.gelu)
@unittest.skip("TODO: fix sigmoid stability")
def test_gelu_extreme(self):
    """Tanh-approximated gelu should stay finite at extreme input magnitudes."""
    torch_gelu = lambda x: torch.nn.functional.gelu(x, approximate="tanh")
    for lo, hi in [(300, 400), (-400, -300)]:
        helper_test_op([(45, 65)], torch_gelu, Tensor.gelu, low=lo, high=hi)
def test_quick_gelu(self):
    """quick_gelu(x) == x * sigmoid(1.702 * x); check on a 2-D tensor and a scalar."""
    ref = lambda x: x * torch.sigmoid(1.702 * x)
    for shape in [(45, 65), ()]:
        helper_test_op([shape], ref, Tensor.quick_gelu)
@unittest.skip("TODO: fix sigmoid stability")
def test_quick_gelu_extreme(self):
    """quick_gelu should stay finite for very large-magnitude inputs."""
    ref = lambda x: x * torch.sigmoid(1.702 * x)
    for lo, hi in [(300, 400), (-400, -300)]:
        helper_test_op([(45, 65)], ref, Tensor.quick_gelu, low=lo, high=hi)
Expand Down
6 changes: 6 additions & 0 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def threefry2x32(x: UOp, key: UOp):

return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)

# ***** other math rewrite ****

# Build t*(1-t)*y with t = 1/(x+1). Algebraically this equals y * x/(x+1)^2
# (the form matched by the "stable sigmoid" rewrite pattern that calls this
# helper), but the factored form never materializes x/(x+1)^2 directly, so the
# generated code stays finite when x is extreme. NOTE(review): x is presumably
# an exponential of the sigmoid input (hence always > 0) — confirm at the
# rewrite call site. Do not re-associate these multiplies; the exact order is
# the point of the rewrite.
def sigmoid_like(x:UOp, y:UOp): return (t:=(1/(x+1))) * (1-t) * y

# ***** main rewriter *****

def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extra=None,vec=None,ne=None,
Expand Down Expand Up @@ -308,6 +312,8 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
# stable sigmoid
(UPat.var("x")*(((UPat.var("x")+1)*(UPat.var("x")+1)).reciprocal()*UPat.var("y")), lambda x,y: sigmoid_like(x,y)),
])

# *** uop expander ***
Expand Down

0 comments on commit 2b2558a

Please sign in to comment.