I know it's an odd bug, and it only happens with the CPU PJRT (compiled last month, see commit hash). With the CUDA PJRT it works as expected. I found it while writing some edge-case unit tests.
Since I maintain the Go wrapper for XlaBuilder (see GoMLX), here are two versions of the Go code, which should be easy to read: one uses a constant as input and the other uses a parameter. At the bottom I include the `HloModuleProto` (less readable) for both cases:
```go
// It works if the input is passed as a constant:
{
	got := ExecOnce(backend, func(g *Graph) *Node {
		return ReduceMax(Const(g, []float64{math.NaN(), 1}))
	})
	// got is NaN, as expected.
}
```
Now the buggy version, which passes the `NaN` as a parameter (so it cannot be optimized away as a constant) and returns 1 when it should return `NaN`:
Here is the resulting `HloModuleProto` of the first version (input is a constant), using `float32`, where `ReduceMax({NaN, 1}) == NaN` as expected.

And here is the HLO version where the input is a parameter, so the constant cannot be optimized away, and where it returns 1 (that is, `ReduceMax({NaN, 1}) == 1`!?).