-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor NuRadioReco/modules/efieldToVoltageConverter.py #707
Changes from all commits
f325d92
ef07ef3
33e4363
cf6f2ff
b444f69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,16 @@ | ||
import numpy as np | ||
import time | ||
import logging | ||
import copy | ||
|
||
import NuRadioReco.framework.channel | ||
import NuRadioReco.framework.base_trace | ||
from NuRadioReco.modules.base.module import register_run | ||
from NuRadioReco.detector import antennapattern | ||
from NuRadioReco.utilities import geometryUtilities as geo_utl | ||
from NuRadioReco.utilities import units, fft | ||
from NuRadioReco.utilities import ice | ||
from NuRadioReco.utilities import trace_utilities | ||
from NuRadioReco.framework.parameters import electricFieldParameters as efp | ||
from NuRadioReco.framework.parameters import stationParameters as stnp | ||
import copy | ||
|
||
from NuRadioReco.modules.base.module import register_run | ||
from NuRadioReco.detector import antennapattern | ||
from NuRadioReco.utilities import units, fft, ice, trace_utilities, geometryUtilities as geo_utl | ||
|
||
|
||
class efieldToVoltageConverter(): | ||
|
@@ -35,9 +34,10 @@ def __init__(self, log_level=logging.NOTSET): | |
self.__post_pulse_time = None | ||
self.__max_upsampling_factor = None | ||
self.antenna_provider = None | ||
self.begin() | ||
self.logger = logging.getLogger('NuRadioReco.efieldToVoltageConverter') | ||
self.logger.setLevel(log_level) | ||
self.begin() | ||
|
||
|
||
def begin(self, debug=False, uncertainty=None, | ||
time_resolution=0.1 * units.ns, | ||
|
@@ -74,20 +74,17 @@ def begin(self, debug=False, uncertainty=None, | |
self.__pre_pulse_time = pre_pulse_time | ||
self.__post_pulse_time = post_pulse_time | ||
self.__max_upsampling_factor = 5000 | ||
if uncertainty is None: | ||
self.__uncertainty = {} | ||
else: | ||
self.__uncertainty = uncertainty | ||
|
||
# some uncertainties are systematic, fix them here | ||
if('sys_dx' in self.__uncertainty): | ||
self.__uncertainty['sys_dx'] = np.random.normal(0, self.__uncertainty['sys_dx']) | ||
if('sys_dy' in self.__uncertainty): | ||
self.__uncertainty['sys_dy'] = np.random.normal(0, self.__uncertainty['sys_dy']) | ||
if('sys_dz' in self.__uncertainty): | ||
self.__uncertainty['sys_dz'] = np.random.normal(0, self.__uncertainty['sys_dz']) | ||
if('sys_amp' in self.__uncertainty): | ||
for iCh in self.__uncertainty['sys_amp'].keys(): | ||
self.__uncertainty = uncertainty or {} | ||
for key in ['sys_dx', 'sys_dy', 'sys_dz']: | ||
if key in self.__uncertainty: | ||
self.__uncertainty[key] = np.random.normal(0, self.__uncertainty[key]) | ||
|
||
if 'sys_amp' in self.__uncertainty: | ||
for iCh in self.__uncertainty['sys_amp']: | ||
self.__uncertainty['sys_amp'][iCh] = np.random.normal(1, self.__uncertainty['sys_amp'][iCh]) | ||
|
||
self.antenna_provider = antennapattern.AntennaPatternProvider() | ||
|
||
@register_run() | ||
|
@@ -96,43 +93,35 @@ def run(self, evt, station, det, channel_ids=None): | |
|
||
# access simulated efield and high level parameters | ||
sim_station = station.get_sim_station() | ||
if(len(sim_station.get_electric_fields()) == 0): | ||
raise LookupError(f"station {station.get_id()} has no efields") | ||
sim_station_id = sim_station.get_id() | ||
if len(sim_station.get_electric_fields()) == 0: | ||
raise LookupError(f"station {station.get_id()} has no efields") | ||
|
||
# first we determine the trace start time of all channels and correct | ||
# for different cable delays | ||
times_min = [] | ||
times_max = [] | ||
if channel_ids is None: | ||
channel_ids = det.get_channel_ids(sim_station_id) | ||
|
||
for channel_id in channel_ids: | ||
for electric_field in sim_station.get_electric_fields_for_channels([channel_id]): | ||
time_resolution = 1. / electric_field.get_sampling_rate() | ||
cab_delay = det.get_cable_delay(sim_station_id, channel_id) | ||
t0 = electric_field.get_trace_start_time() + cab_delay | ||
|
||
# if we have a cosmic ray event, the different signal travel time to the antennas has to be taken into account | ||
if sim_station.is_cosmic_ray(): | ||
site = det.get_site(sim_station_id) | ||
antenna_position = det.get_relative_position(sim_station_id, channel_id) - electric_field.get_position() | ||
if sim_station.get_parameter(stnp.zenith) > 90 * units.deg: # signal is coming from below, so we take IOR of ice | ||
index_of_refraction = ice.get_refractive_index(antenna_position[2], site) | ||
else: # signal is coming from above, so we take IOR of air | ||
index_of_refraction = ice.get_refractive_index(1, site) | ||
# For cosmic ray events, we only have one electric field for all channels, so we have to account | ||
# for the difference in signal travel between channels. IMPORTANT: This is only accurate | ||
# if all channels have the same z coordinate | ||
travel_time_shift = geo_utl.get_time_delay_from_direction( | ||
sim_station.get_parameter(stnp.zenith), | ||
sim_station.get_parameter(stnp.azimuth), | ||
antenna_position, | ||
index_of_refraction | ||
) | ||
travel_time_shift = calculate_time_shift_for_cosmic_ray(det, sim_station, electric_field, channel_id) | ||
t0 += travel_time_shift | ||
if(not np.isnan(t0)): # trace start time is None if no ray tracing solution was found and channel contains only zeros | ||
|
||
if not np.isnan(t0): | ||
# trace start time is None if no ray tracing solution was found and channel contains only zeros | ||
times_min.append(t0) | ||
times_max.append(t0 + electric_field.get_number_of_samples() / electric_field.get_sampling_rate()) | ||
self.logger.debug("trace start time {}, cab_delty {}, tracelength {}".format(electric_field.get_trace_start_time(), cab_delay, electric_field.get_number_of_samples() / electric_field.get_sampling_rate())) | ||
self.logger.debug("trace start time {}, cab_delty {}, tracelength {}".format( | ||
electric_field.get_trace_start_time(), cab_delay, | ||
electric_field.get_number_of_samples() / electric_field.get_sampling_rate())) | ||
|
||
# pad event times by pre/post pulse time | ||
times_min = np.array(times_min) - self.__pre_pulse_time | ||
|
@@ -142,7 +131,10 @@ def run(self, evt, station, det, channel_ids=None): | |
trace_length_samples = int(round(trace_length / time_resolution)) | ||
if trace_length_samples % 2 != 0: | ||
trace_length_samples += 1 | ||
self.logger.debug("smallest trace start time {:.1f}, largest trace time {:.1f} -> n_samples = {:d} {:.0f}ns)".format(times_min.min(), times_max.max(), trace_length_samples, trace_length / units.ns)) | ||
|
||
self.logger.debug( | ||
"smallest trace start time {:.1f}, largest trace time {:.1f} -> n_samples = {:d} {:.0f}ns)".format( | ||
times_min.min(), times_max.max(), trace_length_samples, trace_length / units.ns)) | ||
|
||
# loop over all channels | ||
for channel_id in channel_ids: | ||
|
@@ -153,44 +145,40 @@ def run(self, evt, station, det, channel_ids=None): | |
# and everything up in the time domain | ||
self.logger.debug('channel id {}'.format(channel_id)) | ||
channel = NuRadioReco.framework.channel.Channel(channel_id) | ||
channel_spectrum = None | ||
trace_object = None | ||
if(self.__debug): | ||
|
||
if self.__debug: | ||
from matplotlib import pyplot as plt | ||
fig, axes = plt.subplots(2, 1) | ||
|
||
channel_spectrum = None | ||
trace_object = None | ||
for electric_field in sim_station.get_electric_fields_for_channels([channel_id]): | ||
|
||
# all simulated channels have a different trace start time | ||
# in a measurement, all channels have the same physical start time | ||
# so we need to create one long trace that can hold all the different channel times | ||
# to achieve a good time resolution, we upsample the trace first. | ||
new_efield = NuRadioReco.framework.base_trace.BaseTrace() # create new data structure with new efield length | ||
new_efield.set_trace(copy.copy(electric_field.get_trace()), electric_field.get_sampling_rate()) | ||
new_trace = np.zeros((3, trace_length_samples)) | ||
|
||
# calculate the start bin | ||
if(not np.isnan(electric_field.get_trace_start_time())): | ||
if not np.isnan(electric_field.get_trace_start_time()): | ||
cab_delay = det.get_cable_delay(sim_station_id, channel_id) | ||
if sim_station.is_cosmic_ray(): | ||
site = det.get_site(sim_station_id) | ||
antenna_position = det.get_relative_position(sim_station_id, channel_id) - electric_field.get_position() | ||
if sim_station.get_parameter(stnp.zenith) > 90 * units.deg: # signal is coming from below, so we take IOR of ice | ||
index_of_refraction = ice.get_refractive_index(antenna_position[2], site) | ||
else: # signal is coming from above, so we take IOR of air | ||
index_of_refraction = ice.get_refractive_index(1, site) | ||
travel_time_shift = geo_utl.get_time_delay_from_direction( | ||
sim_station.get_parameter(stnp.zenith), | ||
sim_station.get_parameter(stnp.azimuth), | ||
antenna_position, | ||
index_of_refraction | ||
) | ||
start_time = electric_field.get_trace_start_time() + cab_delay - times_min.min() + travel_time_shift | ||
start_bin = int(round(start_time / time_resolution)) | ||
time_remainder = start_time - start_bin * time_resolution | ||
travel_time_shift = calculate_time_shift_for_cosmic_ray( | ||
det, sim_station, electric_field, channel_id) | ||
else: | ||
start_time = electric_field.get_trace_start_time() + cab_delay - times_min.min() | ||
start_bin = int(round(start_time / time_resolution)) | ||
time_remainder = start_time - start_bin * time_resolution | ||
self.logger.debug('channel {}, start time {:.1f} = bin {:d}, ray solution {}'.format(channel_id, electric_field.get_trace_start_time() + cab_delay, start_bin, electric_field[efp.ray_path_type])) | ||
travel_time_shift = 0 | ||
|
||
start_time = electric_field.get_trace_start_time() + cab_delay - times_min.min() + travel_time_shift | ||
start_bin = int(round(start_time / time_resolution)) | ||
|
||
# calculate error by using discret bins | ||
time_remainder = start_time - start_bin * time_resolution | ||
self.logger.debug('channel {}, start time {:.1f} = bin {:d}, ray solution {}'.format( | ||
channel_id, electric_field.get_trace_start_time() + cab_delay, start_bin, electric_field[efp.ray_path_type])) | ||
|
||
new_efield = NuRadioReco.framework.base_trace.BaseTrace() # create new data structure with new efield length | ||
new_efield.set_trace(copy.copy(electric_field.get_trace()), electric_field.get_sampling_rate()) | ||
new_efield.apply_time_shift(time_remainder) | ||
|
||
tr = new_efield.get_trace() | ||
|
@@ -201,28 +189,34 @@ def run(self, evt, station, det, channel_ids=None): | |
# ensure new efield does not extend beyond end of trace although this should not happen | ||
self.logger.warning("electric field trace extends beyond the end of the trace and will be cut.") | ||
stop_bin = np.shape(new_trace)[-1] | ||
tr = np.atleast_2d(tr)[:,:stop_bin-start_bin] | ||
tr = np.atleast_2d(tr)[:, :stop_bin-start_bin] | ||
|
||
if start_bin < 0: | ||
# ensure new efield does not extend beyond start of trace although this should not happen | ||
self.logger.warning("electric field trace extends beyond the beginning of the trace and will be cut.") | ||
tr = np.atleast_2d(tr)[:,-start_bin:] | ||
tr = np.atleast_2d(tr)[:, -start_bin:] | ||
start_bin = 0 | ||
|
||
new_trace[:, start_bin:stop_bin] = tr | ||
|
||
trace_object = NuRadioReco.framework.base_trace.BaseTrace() | ||
trace_object.set_trace(new_trace, 1. / time_resolution) | ||
if(self.__debug): | ||
|
||
if self.__debug: | ||
axes[0].plot(trace_object.get_times(), new_trace[1], label="eTheta {}".format(electric_field[efp.ray_path_type]), c='C0') | ||
axes[0].plot(trace_object.get_times(), new_trace[2], label="ePhi {}".format(electric_field[efp.ray_path_type]), c='C0', linestyle=':') | ||
axes[0].plot(electric_field.get_times(), electric_field.get_trace()[1], c='C1', linestyle='-', alpha=.5) | ||
axes[0].plot(electric_field.get_times(), electric_field.get_trace()[2], c='C1', linestyle=':', alpha=.5) | ||
|
||
ff = trace_object.get_frequencies() | ||
efield_fft = trace_object.get_frequency_spectrum() | ||
|
||
zenith = electric_field[efp.zenith] | ||
azimuth = electric_field[efp.azimuth] | ||
|
||
# get antenna pattern for current channel | ||
VEL = trace_utilities.get_efield_antenna_factor(sim_station, ff, [channel_id], det, zenith, azimuth, self.antenna_provider) | ||
VEL = trace_utilities.get_efield_antenna_factor( | ||
sim_station, ff, [channel_id], det, zenith, azimuth, self.antenna_provider) | ||
|
||
if VEL is None: # this can happen if there is not signal path to the antenna | ||
voltage_fft = np.zeros_like(efield_fft[1]) # set voltage trace to zeros | ||
|
@@ -234,31 +228,36 @@ def run(self, evt, station, det, channel_ids=None): | |
# Remove DC offset | ||
voltage_fft[np.where(ff < 5 * units.MHz)] = 0. | ||
|
||
if(self.__debug): | ||
axes[1].plot(trace_object.get_times(), fft.freq2time(voltage_fft, electric_field.get_sampling_rate()), label="{}, zen = {:.0f}deg".format(electric_field[efp.ray_path_type], zenith / units.deg)) | ||
if self.__debug: | ||
axes[1].plot( | ||
trace_object.get_times(), fft.freq2time(voltage_fft, electric_field.get_sampling_rate()), | ||
label="{}, zen = {:.0f}deg".format(electric_field[efp.ray_path_type], zenith / units.deg)) | ||
|
||
if('amp' in self.__uncertainty): | ||
if 'amp' in self.__uncertainty: | ||
voltage_fft *= np.random.normal(1, self.__uncertainty['amp'][channel_id]) | ||
if('sys_amp' in self.__uncertainty): | ||
|
||
if 'sys_amp' in self.__uncertainty: | ||
voltage_fft *= self.__uncertainty['sys_amp'][channel_id] | ||
|
||
if(channel_spectrum is None): | ||
if channel_spectrum is None: | ||
channel_spectrum = voltage_fft | ||
else: | ||
channel_spectrum += voltage_fft | ||
|
||
if(self.__debug): | ||
if self.__debug: | ||
axes[0].legend(loc='upper left') | ||
axes[1].legend(loc='upper left') | ||
plt.show() | ||
|
||
if trace_object is None: # this happens if don't have any efield for this channel | ||
# set the trace to zeros | ||
channel.set_trace(np.zeros(trace_length_samples), 1. / time_resolution) | ||
else: | ||
channel.set_frequency_spectrum(channel_spectrum, trace_object.get_sampling_rate()) | ||
channel.set_trace_start_time(times_min.min()) | ||
|
||
channel.set_trace_start_time(times_min.min()) | ||
station.add_channel(channel) | ||
|
||
self.__t += time.time() - t | ||
|
||
def end(self): | ||
|
@@ -267,3 +266,39 @@ def end(self): | |
dt = timedelta(seconds=self.__t) | ||
self.logger.info("total time used by this module is {}".format(dt)) | ||
return dt | ||
|
||
|
||
def calculate_time_shift_for_cosmic_ray(det, station, efield, channel_id): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also not sure if it's worth making this function private, I guess we don´t expect it to be used on its own. But this is anyway not something we're very consistent about throughout NuRadio. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. next PR? Is this function useful for the other efield2Voltage converter modules? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I'm happy for this to be merged as is. |
||
""" | ||
Calculate the time shift for a cosmic ray event | ||
|
||
Parameters | ||
---------- | ||
det : Detector | ||
station : Station | ||
efield : ElectricField | ||
|
||
Returns | ||
------- | ||
float | ||
time shift in ns | ||
""" | ||
station_id = station.get_id() | ||
site = det.get_site(station_id) | ||
antenna_position = det.get_relative_position(station_id, channel_id) - efield.get_position() | ||
if station.get_parameter(stnp.zenith) > 90 * units.deg: # signal is coming from below, so we take IOR of ice | ||
index_of_refraction = ice.get_refractive_index(antenna_position[2], site) | ||
else: # signal is coming from above, so we take IOR of air | ||
index_of_refraction = ice.get_refractive_index(1, site) | ||
|
||
# For cosmic ray events, we only have one electric field for all channels, so we have to account | ||
# for the difference in signal travel between channels. IMPORTANT: This is only accurate | ||
# if all channels have the same z coordinate | ||
travel_time_shift = geo_utl.get_time_delay_from_direction( | ||
station.get_parameter(stnp.zenith), | ||
station.get_parameter(stnp.azimuth), | ||
antenna_position, | ||
index_of_refraction | ||
) | ||
|
||
return travel_time_shift |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder what percentage of python users finds this more intuitive than the four-line equivalent... I'm not a convert yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if that is the case: ciao.