Skip to content

Commit

Permalink
Fix shape annotations for sigmoid-bce loss.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553113646
  • Loading branch information
The kauldron Authors committed Aug 2, 2023
1 parent d4b3d8c commit e51f224
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion kauldron/losses/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,7 @@ class SigmoidBinaryCrossEntropy(base.Loss):
labels: Key = "batch.label"

@typechecked
def get_values(self, logits: Float["*a"], labels: Float["*a"]) -> Float["*a"]:
def get_values(
self, logits: Float["*a n"], labels: Int["*a n"]
) -> Float["*a 1"]:
return optax.sigmoid_binary_cross_entropy(logits, labels)

0 comments on commit e51f224

Please sign in to comment.