Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trigger Readout Window #763

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions NuRadioMC/simulation/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import NuRadioReco.modules.efieldToVoltageConverter
import NuRadioReco.modules.channelAddCableDelay
import NuRadioReco.modules.channelResampler
import NuRadioReco.modules.triggerTimeAdjuster
import NuRadioReco.modules.channelReadoutWindowCutter
from NuRadioReco.detector import detector
import NuRadioReco.framework.sim_station
import NuRadioReco.framework.electric_field
Expand Down Expand Up @@ -59,7 +59,7 @@
channelGenericNoiseAdder = NuRadioReco.modules.channelGenericNoiseAdder.channelGenericNoiseAdder()
channelResampler = NuRadioReco.modules.channelResampler.channelResampler()
eventWriter = NuRadioReco.modules.io.eventWriter.eventWriter()
triggerTimeAdjuster = NuRadioReco.modules.triggerTimeAdjuster.triggerTimeAdjuster()
channelReadoutWindowCutter = NuRadioReco.modules.channelReadoutWindowCutter.channelReadoutWindowCutter()

def merge_config(user, default):
"""
Expand Down Expand Up @@ -1519,7 +1519,7 @@ def run(self):
if not evt.get_station().has_triggered():
continue

triggerTimeAdjuster.run(evt, station, self._det)
channelReadoutWindowCutter.run(evt, station, self._det)
evt_group_triggered = True
output_buffer[station_id][evt.get_id()] = evt
# end event loop
Expand Down
2 changes: 2 additions & 0 deletions NuRadioReco/modules/channelLengthAdjuster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class channelLengthAdjuster:
def __init__(self):
self.number_of_samples = None
self.offset = None
logger.warning("In most cases it is advisable to run a trigger module and use the channelReadoutWindowCutter module to cut the traces to the readout window \
instead of this simple module.")
cg-laser marked this conversation as resolved.
Show resolved Hide resolved
self.begin(())

def begin(self, number_of_samples=256, offset=50):
Expand Down
137 changes: 137 additions & 0 deletions NuRadioReco/modules/channelReadoutWindowCutter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from NuRadioReco.modules.base.module import register_run
import numpy as np
import logging
from NuRadioReco.utilities import units

logger = logging.getLogger('NuRadioReco.channelReadoutWindowCutter')


class channelReadoutWindowCutter:
"""
Modifies channel traces to simulate the effects of the trigger

The trace is cut to the length defined in the detector description relative to the trigger time.
If no trigger exists, nothing is done.
"""

def __init__(self, log_level=logging.NOTSET):
logger.setLevel(log_level)
self.__sampling_rate_warning_issued = False
self.begin()

def begin(self):
pass

@register_run()
def run(self, event, station, detector):
"""
Cuts the traces to the readout window defined in the trigger.

If multiple triggers exist, the primary trigger is used. If multiple
primary triggers exist, an error is raised.
If no primary trigger exists, the trigger with the earliest trigger time
is defined as the primary trigger and used to set the readout windows.

Parameters
----------
event: `NuRadioReco.framework.event.Event`

station: `NuRadioReco.framework.base_station.Station`

detector: `NuRadioReco.detector.detector.Detector`
"""
counter = 0
for i, (name, instance, kwargs) in enumerate(event.iter_modules(station.get_id())):
if name == 'channelReadoutWindowCutter':
counter += 1
if counter > 1:
logger.warning('channelReadoutWindowCutter was called twice. '
'This is likely a mistake. The module will not be applied again.')
return 0


# determine which trigger to use
# if no primary trigger exists, use the trigger with the earliest trigger time
trigger = station.get_primary_trigger()
if trigger is None: # no primary trigger found
logger.debug('No primary trigger found. Using the trigger with the earliest trigger time.')
trigger = station.get_first_trigger()
if trigger is not None:
logger.info(f"setting trigger {trigger.get_name()} primary because it triggered first")
trigger.set_primary(True)

if trigger is None:
logger.info('No trigger found! Channel timings will not be changed.')
return

if trigger.has_triggered():
trigger_time = trigger.get_trigger_time()
for channel in station.iter_channels():
trigger_time_channel = trigger_time - channel.get_trace_start_time()
# if trigger_time_channel == 0:
# logger.warning(f"the trigger time is equal to the trace start time for channel {channel.get_id()}. This is likely because this module was already run on this station. The trace will not be changed.")
# continue

trace = channel.get_trace()
trace_length = len(trace)
detector_sampling_rate = detector.get_sampling_frequency(station.get_id(), channel.get_id())
sampling_rate = channel.get_sampling_rate()
self.__check_sampling_rates(detector_sampling_rate, sampling_rate)

# this should ensure that 1) the number of samples is even and
# 2) resampling to the detector sampling rate results in the correct number of samples
# (note that 2) can only be guaranteed if the detector sampling rate is lower than the
# current sampling rate)
number_of_samples = int(
2 * np.ceil(
detector.get_number_of_samples(station.get_id(), channel.get_id()) / 2
* sampling_rate / detector_sampling_rate
))

if number_of_samples > trace.shape[0]:
logger.error("Input has fewer samples than desired output. Channels has only {} samples but {} samples are requested.".format(
trace.shape[0], number_of_samples))
raise AttributeError
else:
trigger_time_sample = int(np.round(trigger_time_channel * sampling_rate))
# logger.debug(f"channel {channel.get_id()}: trace_start_time = {channel.get_trace_start_time():.1f}ns, trigger time channel {trigger_time_channel/units.ns:.1f}ns, trigger time sample = {trigger_time_sample}")
channel_id = channel.get_id()
pre_trigger_time = trigger.get_pre_trigger_time_channel(channel_id)
samples_before_trigger = int(pre_trigger_time * sampling_rate)
cut_samples_beginning = 0
if(samples_before_trigger <= trigger_time_sample):
cut_samples_beginning = trigger_time_sample - samples_before_trigger
roll_by = 0
if(cut_samples_beginning + number_of_samples > trace_length):
logger.warning("trigger time is sample {} but total trace length is only {} samples (requested trace length is {} with an offest of {} before trigger). To achieve desired configuration, trace will be rolled".format(
trigger_time_sample, trace_length, number_of_samples, samples_before_trigger))
roll_by = cut_samples_beginning + number_of_samples - trace_length # roll_by is positive
trace = np.roll(trace, -1 * roll_by)
cut_samples_beginning -= roll_by
rel_station_time_samples = cut_samples_beginning + roll_by
elif(samples_before_trigger > trigger_time_sample):
roll_by = -trigger_time_sample + samples_before_trigger
logger.warning(f"trigger time is before 'trigger offset window' (requested samples before trigger = {samples_before_trigger}," \
f"trigger time sample = {trigger_time_sample}), the trace needs to be rolled by {roll_by} samples first" \
f" = {roll_by / sampling_rate/units.ns:.2f}ns")
trace = np.roll(trace, roll_by)

# shift trace to be in the correct location for cutting
trace = trace[cut_samples_beginning:(number_of_samples + cut_samples_beginning)]
channel.set_trace(trace, channel.get_sampling_rate())
channel.set_trace_start_time(trigger_time - pre_trigger_time)
# channel.set_trace_start_time(channel.get_trace_start_time() + rel_station_time_samples / channel.get_sampling_rate())
# logger.debug(f"setting trace start time to {channel.get_trace_start_time() + rel_station_time_samples / channel.get_sampling_rate():.0f} = {channel.get_trace_start_time():.0f} + {rel_station_time_samples / channel.get_sampling_rate():.0f}")



def __check_sampling_rates(self, detector_sampling_rate, channel_sampling_rate):
if not self.__sampling_rate_warning_issued: # we only issue this warning once
if not np.isclose(detector_sampling_rate, channel_sampling_rate):
logger.warning(
'triggerTimeAdjuster was called, but the channel sampling rate '
f'({channel_sampling_rate/units.GHz:.3f} GHz) is not equal to '
f'the target detector sampling rate ({detector_sampling_rate/units.GHz:.3f} GHz). '
'Traces may not have the correct trace length after resampling.'
)
self.__sampling_rate_warning_issued = True
Comment on lines +128 to +137
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this warning at all? We ensure that we have the correct number of samples with:


                # this should ensure that 1) the number of samples is even and
                # 2) resampling to the detector sampling rate results in the correct number of samples
                # (note that 2) can only be guaranteed if the detector sampling rate is lower than the
                # current sampling rate)
                number_of_samples = int(
                    2 * np.ceil(
                        detector.get_number_of_samples(station.get_id(), channel.get_id()) / 2
                        * sampling_rate / detector_sampling_rate
                    ))

Maybe we can only check that detector_sampling_rate <= channel_sampling_rate?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when I implemented this I thought about it sufficiently to realize that if sampling_rate < detector_sampling_rate we cannot guarantee the correct number of samples (with the way resampling is currently implemented); I am not actually sure if the correct number of samples is guaranteed in all cases if sampling_rate > detector_sampling_rate; my feeling was that this is only true in most cases which is why I left the warning as it is. If you can work out the maths and figure out that this is wrong and we don't need the warning we can change this : )

6 changes: 6 additions & 0 deletions NuRadioReco/modules/triggerTimeAdjuster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def __init__(self, log_level=logging.NOTSET):
logger.setLevel(log_level)
self.__sampling_rate_warning_issued = False
self.begin()
logger.warning(f"triggerTimeAdjuster is deprecated and will be removed soon. In most cased you can safely delete the application\
of this module as it is automatically applied in NuRadioMC simulations. If you really need to use this module, \
please use the channelReadoutWindowCutter module instead.")
cg-laser marked this conversation as resolved.
Show resolved Hide resolved

def begin(self):
pass
Expand Down Expand Up @@ -76,6 +79,9 @@ def run(self, event, station, detector, mode='sim_to_data'):
If the ``trigger_name`` was specified in the ``begin`` function,
only this trigger is considered.
"""
logger.warning(f"triggerTimeAdjuster is deprecated and will be removed soon. In most cased you can safely delete the application\
of this module as it is automatically applied in NuRadioMC simulations. If you really need to use this module, \
please use the channelReadoutWindowCutter module instead.")
cg-laser marked this conversation as resolved.
Show resolved Hide resolved
counter = 0
for i, (name, instance, kwargs) in enumerate(event.iter_modules(station.get_id())):
if name == 'triggerTimeAdjuster':
Expand Down
Loading