Commit 8f5455d

just sigmoid gradient

chenyuxyz committed Dec 14, 2024
1 parent 0344c68, commit 8f5455d

Showing 2 changed files with 5 additions and 0 deletions.
test/test_ops.py: 4 additions & 0 deletions

@@ -810,6 +810,10 @@ def test_sigmoid(self):
   def test_sigmoid_extreme(self):
     helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=400)
     helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-400, high=-300)
+    x = Tensor([300.0])
+    self.assertAlmostEqual(x.sigmoid()[0].gradient(x)[0].item(), 0.0)
+    x = Tensor([-300.0])
+    self.assertAlmostEqual(x.sigmoid()[0].gradient(x)[0].item(), 0.0)
   def test_hardsigmoid(self):
     helper_test_op([(45,65)], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid)
     helper_test_op([()], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid)
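The new assertions pin down that the gradient of Tensor.sigmoid at extreme inputs (x = ±300) is 0.0 rather than nan or inf. Not part of the commit, but as a rough illustration of the failure mode: if the gradient is evaluated directly as t / (1 + t)^2, where t stands for the exponential term from the forward pass, t overflows in float32 for large negative x and the quotient becomes nan; the same quantity factored through the sigmoid value s = 1 / (1 + t) stays finite. A minimal NumPy sketch of that arithmetic (names chosen here for illustration only):

import numpy as np

# With t = exp(-x), the sigmoid gradient is t / (1 + t)**2.
# At x = -300, t overflows float32 to inf and inf / inf evaluates to nan.
x = np.float32(-300.0)
t = np.exp(-x, dtype=np.float32)                              # inf (overflow)
naive_grad = t / ((t + np.float32(1)) * (t + np.float32(1)))  # nan

# The same quantity factored through s = 1 / (1 + t), i.e. s * (1 - s),
# evaluates cleanly to 0.0 at the same input.
s = np.float32(1) / (np.float32(1) + t)                       # 0.0
stable_grad = s * (np.float32(1) - s)                         # 0.0

print(naive_grad, stable_grad)                                # nan 0.0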
tinygrad/codegen/uopgraph.py: 1 addition & 0 deletions

@@ -313,6 +313,7 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
     lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
       if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
   # stable sigmoid
+  (UPat.var("x")*(((UPat.var("x")+1)*(UPat.var("x")+1)).reciprocal()), lambda x: sigmoid_like(x, x.const_like(1))),
   (UPat.var("x")*(((UPat.var("x")+1)*(UPat.var("x")+1)).reciprocal()*UPat.var("y")), sigmoid_like),
   (UPat.var("x")*(((UPat.var("x")+1)*(UPat.var("x")+1)*(UPat.var("x")+1)).reciprocal()), lambda x: sigmoid_like(x, (x+1).reciprocal())),
 ])
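A note that is not part of the diff: with s = 1/(x+1), the matched expression x*(((x+1)*(x+1)).reciprocal()) is algebraically s*(1 - s), i.e. the sigmoid gradient written in terms of the exponential term x that the forward pass already produces. The pre-existing patterns cover that gradient scaled by an upstream factor y and a related cubed form; given the commit message and the single-line addition, the new entry appears to be the bare-gradient case with the scale fixed to a constant 1, which is what lets the extreme-input gradient test above return 0.0 instead of nan. A quick SymPy check of the underlying identity, for illustration only:

import sympy as sp

t = sp.symbols('t', positive=True)
s = 1 / (t + 1)
# the matched form t / (t + 1)**2 equals s * (1 - s), the sigmoid gradient
# expressed through the sigmoid value s itself
assert sp.simplify(t / (t + 1)**2 - s * (1 - s)) == 0
print("identity holds")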
