CPU PJRT: reduce max of NaNs different if the value is a constant or if the value comes from a parameter #21461

Open
janpfeifer opened this issue Jan 15, 2025 · 2 comments
janpfeifer (Contributor) commented Jan 15, 2025

I know, it's an odd bug, and it only happens in the CPU PJRT (compiled last month; see commit hash). With the CUDA PJRT it works as expected. I found it while writing some edge-case unit tests.

Since I maintain the Go wrapper for XlaBuilder (see GoMLX), here are the two versions of the Go code: one passing the input as a constant, the other passing it as a parameter -- both should be easy to read. At the bottom I include the HloModuleProto (less readable) for both cases:

	// It works if the input is passed as a constant:
	{
		got := ExecOnce(backend, func(g *Graph) *Node {
			return ReduceMax(Const(g, []float64{math.NaN(), 1}))
		})
		// got is NaN, as expected.
	}

Now the buggy version, which passes the NaN as a parameter -- so it cannot be folded away as a constant -- and returns 1 when it should return NaN:

	got := ExecOnce(backend, func(x *Node) *Node {
		return ReduceMax(x)
	}, []float64{math.NaN(), 1})
	// got is 1 !?!?!

Here is the resulting HloModuleProto for the first version (input as a constant), using float32, where ReduceMax({NaN, 1}) == NaN, as expected:

name: "TestReduce-ReduceMax with NaN as constant.8"
entry_computation_name: "TestReduce-ReduceMax with NaN as constant.8"
computations {
  name: "#_ReduceMaxType_Float32.3"
  instructions {
    name: "lhs.4"
    opcode: "parameter"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    id: 4
    frontend_attributes {
    }
  }
  instructions {
    name: "rhs.5"
    opcode: "parameter"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    parameter_number: 1
    id: 5
    frontend_attributes {
    }
  }
  instructions {
    name: "maximum.6"
    opcode: "maximum"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    id: 6
    operand_ids: 4
    operand_ids: 5
    frontend_attributes {
    }
  }
  program_shape {
    parameters {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    parameters {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    result {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    parameter_names: "lhs"
    parameter_names: "rhs"
  }
  id: 3
  root_id: 6
}
computations {
  name: "TestReduce-ReduceMax with NaN as constant.8"
  instructions {
    name: "constant.1"
    opcode: "constant"
    shape {
      element_type: F32
      dimensions: 2
      layout {
        minor_to_major: 0
        tail_padding_alignment_in_elements: 1
      }
      is_dynamic_dimension: false
    }
    metadata {
    }
    literal {
      shape {
        element_type: F32
        dimensions: 2
        layout {
          minor_to_major: 0
          tail_padding_alignment_in_elements: 1
        }
        is_dynamic_dimension: false
      }
      f32s: nan
      f32s: 1
    }
    id: 1
    frontend_attributes {
    }
  }
  instructions {
    name: "constant.2"
    opcode: "constant"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    literal {
      shape {
        element_type: F32
        layout {
          tail_padding_alignment_in_elements: 1
        }
      }
      f32s: -inf
    }
    id: 2
    frontend_attributes {
    }
  }
  instructions {
    name: "reduce.7"
    opcode: "reduce"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    dimensions: 0
    id: 7
    operand_ids: 1
    operand_ids: 2
    called_computation_ids: 3
    frontend_attributes {
    }
  }
  program_shape {
    result {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
  }
  id: 8
  root_id: 7
}
host_program_shape {
  result {
    element_type: F32
    layout {
      tail_padding_alignment_in_elements: 1
    }
  }
}
id: 8
entry_computation_id: 8

And here is the HloModuleProto for the version where the input is a parameter, so it cannot be folded away as a constant, and which returns 1 (that is, ReduceMax({NaN, 1}) == 1!?):

name: "TestReduce-ReduceMax with NaN as parameter.8"
entry_computation_name: "TestReduce-ReduceMax with NaN as parameter.8"
computations {
  name: "#_ReduceMaxType_Float32.3"
  instructions {
    name: "lhs.4"
    opcode: "parameter"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    id: 4
    frontend_attributes {
    }
  }
  instructions {
    name: "rhs.5"
    opcode: "parameter"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    parameter_number: 1
    id: 5
    frontend_attributes {
    }
  }
  instructions {
    name: "maximum.6"
    opcode: "maximum"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    id: 6
    operand_ids: 4
    operand_ids: 5
    frontend_attributes {
    }
  }
  program_shape {
    parameters {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    parameters {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    result {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    parameter_names: "lhs"
    parameter_names: "rhs"
  }
  id: 3
  root_id: 6
}
computations {
  name: "TestReduce-ReduceMax with NaN as parameter.8"
  instructions {
    name: "x.1"
    opcode: "parameter"
    shape {
      element_type: F32
      dimensions: 2
      layout {
        minor_to_major: 0
        tail_padding_alignment_in_elements: 1
      }
      is_dynamic_dimension: false
    }
    metadata {
    }
    id: 1
    frontend_attributes {
    }
  }
  instructions {
    name: "constant.2"
    opcode: "constant"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    literal {
      shape {
        element_type: F32
        layout {
          tail_padding_alignment_in_elements: 1
        }
      }
      f32s: -inf
    }
    id: 2
    frontend_attributes {
    }
  }
  instructions {
    name: "reduce.7"
    opcode: "reduce"
    shape {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    metadata {
    }
    dimensions: 0
    id: 7
    operand_ids: 1
    operand_ids: 2
    called_computation_ids: 3
    frontend_attributes {
    }
  }
  program_shape {
    parameters {
      element_type: F32
      dimensions: 2
      layout {
        minor_to_major: 0
        tail_padding_alignment_in_elements: 1
      }
      is_dynamic_dimension: false
    }
    result {
      element_type: F32
      layout {
        tail_padding_alignment_in_elements: 1
      }
    }
    parameter_names: "x"
  }
  id: 8
  root_id: 7
}
host_program_shape {
  parameters {
    element_type: F32
    dimensions: 2
    layout {
      minor_to_major: 0
      tail_padding_alignment_in_elements: 1
    }
    is_dynamic_dimension: false
  }
  result {
    element_type: F32
    layout {
      tail_padding_alignment_in_elements: 1
    }
  }
  parameter_names: "x"
}
id: 8
entry_computation_id: 8
mooskagh (Member) commented:
The two protobufs mentioned in the bug, as HLO snippets:

HloModule TestReduce-ReduceMax_with_NaN_as_constant.8, entry_computation_layout={()->f32[]}

a_ReduceMaxType_Float32.3 {
  lhs.4 = f32[] parameter(0)
  rhs.5 = f32[] parameter(1)
  ROOT maximum.6 = f32[] maximum(lhs.4, rhs.5)
}

ENTRY TestReduce-ReduceMax_with_NaN_as_constant.8 {
  constant.1 = f32[2]{0} constant({nan, 1})
  constant.2 = f32[] constant(-inf)
  ROOT reduce.7 = f32[] reduce(constant.1, constant.2), dimensions={0}, to_apply=a_ReduceMaxType_Float32.3
}

HloModule TestReduce-ReduceMax_with_NaN_as_parameter.8, entry_computation_layout={(f32[2]{0})->f32[]}

a_ReduceMaxType_Float32.3 {
  lhs.4 = f32[] parameter(0)
  rhs.5 = f32[] parameter(1)
  ROOT maximum.6 = f32[] maximum(lhs.4, rhs.5)
}

ENTRY TestReduce-ReduceMax_with_NaN_as_parameter.8 {
  x.1 = f32[2]{0} parameter(0)
  constant.2 = f32[] constant(-inf)
  ROOT reduce.7 = f32[] reduce(x.1, constant.2), dimensions={0}, to_apply=a_ReduceMaxType_Float32.3
}

janpfeifer (Contributor, Author) commented:

> as HLO snippets:

Thanks @mooskagh !
