Skip to content

Commit

Permalink
Update the jump detection unittest to check
Browse files Browse the repository at this point in the history
if jumps are correctly identified
This checks if jumps are correctly identified, resolving some lingering
test issues from #227
  • Loading branch information
WilliamJamieson committed Nov 14, 2023
1 parent 36d9d14 commit d758196
Showing 1 changed file with 24 additions and 43 deletions.
67 changes: 24 additions & 43 deletions tests/test_jump_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
# that the random process will "accidentally" generate a set of data, which
# can trigger jump detection. This makes it easier to cleanly test jump
# detection is doing what we expect.
FLUX = 100
READ_NOISE = np.float32(5)
FLUX = 10
READ_NOISE = np.float32(20)

# Set a value for jumps which makes them obvious relative to the normal flux
JUMP_VALUE = 1_000
Expand All @@ -32,7 +32,7 @@
# across all tests to make it easier to isolate the effects of something using
# multiple tests.
N_PIXELS = 100_000
CHI2_TOL = 0.03
CHI2_TOL = 0.3
GOOD_PROB = 0.7


Expand Down Expand Up @@ -353,8 +353,7 @@ def test_fit_ramps(detector_data, use_jump, use_dq):

chi2 = 0
for fit, use in zip(output.fits, okay):
if not use_dq and not use_jump:
##### The not use_jump makes this NOT test for false positives #####
if not use_dq:
# Check that the data generated does not generate any false positives
# for jumps as this data is reused for `test_find_jumps` below.
# This guarantees that all jumps detected in that test are the
Expand Down Expand Up @@ -468,10 +467,6 @@ def test_find_jumps(jump_data):
assert len(output.fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel

chi2 = 0
incorrect_too_few = 0
incorrect_too_many = 0
incorrect_does_not_capture = 0
incorrect_other = 0
for fit, jump_index, resultant_index in zip(output.fits, jump_reads, jump_resultants):
# Check that the jumps are detected correctly
if jump_index == 0:
Expand All @@ -485,51 +480,37 @@ def test_find_jumps(jump_data):
assert fit["index"][0]["start"] == 0
assert fit["index"][0]["end"] == len(read_pattern) - 1
else:
# There should be a single jump detected; however, this results in
# two resultants being excluded.
if resultant_index not in fit["jumps"]:
incorrect_does_not_capture += 1
continue
if len(fit["jumps"]) < 2:
incorrect_too_few += 1
continue
if len(fit["jumps"]) > 2:
incorrect_too_many += 1
continue

# The two resultants excluded should be adjacent
jump_correct = [
(jump in (resultant_index, resultant_index - 1, resultant_index + 1)) for jump in fit["jumps"]
]
if not all(jump_correct):
incorrect_other += 1
continue

# Check that the inserted jump is detected or if the jump occurs in the last resultant
# (there are some unresolved issues with this case)
assert resultant_index in fit["jumps"] or resultant_index == resultants.shape[0] - 1
# Because we do not have a data set with no false positives, we cannot run the below
# # Test the correct ramp indexes are recorded
# ramp_indices = []
# for ramp_index in fit['index']:
# # Note start/end of a ramp_index are inclusive meaning that end
# # is an index included in the ramp_index so the range is to end + 1
# new_indices = list(range(ramp_index["start"], ramp_index["end"] + 1))
# Test the correct ramp indexes are recorded

# Here we map out all of the ramps and make sure they are non-overlapping and that
# they do not overlap with the identified jumps
ramp_indices = []
for ramp_index in fit["index"]:
# Note start/end of a ramp_index are inclusive meaning that end
# is an index included in the ramp_index so the range is to end + 1
new_indices = list(range(ramp_index["start"], ramp_index["end"] + 1))

# # check that all the ramps are non-overlapping
# assert set(ramp_indices).isdisjoint(new_indices)
# check that all the ramps are non-overlapping
assert set(ramp_indices).isdisjoint(new_indices)

# ramp_indices.extend(new_indices)
ramp_indices.extend(new_indices)

# # check that no ramp_index is a jump
# assert set(ramp_indices).isdisjoint(fit['jumps'])
# check that no ramp_index is a jump
assert set(ramp_indices).isdisjoint(fit["jumps"])

# # check that all resultant indices are either in a ramp or listed as a jump
# assert set(ramp_indices).union(fit['jumps']) == set(range(len(read_pattern)))
# check that all resultant indices are either in a ramp or listed as a jump
assert set(ramp_indices).union(fit["jumps"]) == set(range(len(read_pattern)))

# Compute the chi2 for the fit and add it to a running "total chi2"
total_var = fit["average"]["read_var"] + fit["average"]["poisson_var"]
chi2 += (fit["average"]["slope"] - FLUX) ** 2 / total_var

# Check that the average chi2 is ~1.
chi2 /= N_PIXELS - incorrect_too_few - incorrect_too_many - incorrect_does_not_capture - incorrect_other
chi2 /= N_PIXELS
assert np.abs(chi2 - 1) < CHI2_TOL


Expand Down

0 comments on commit d758196

Please sign in to comment.