Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cw filter notch module #724

Merged
merged 17 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 295 additions & 0 deletions NuRadioReco/modules/channelCWNotchFilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
import logging
logger = logging.getLogger("NuRadioReco.channelCWNotchFilter")
import time
import numpy as np
from scipy import signal
from NuRadioReco.utilities import units
from NuRadioReco.utilities import fft

"""
Contains module to filter continuous wave out of the signal using notch filters
on peaks in frequency spectrum
"""


def find_frequency_peaks_from_trace(trace : np.ndarray, fs : float, threshold : float = 4):
"""
Function fo find the frequency peaks in the real fourier transform of the input trace.

Parameters
----------
trace : np.ndarray
Waveform
fs : float
Sampling frequency, (input should be taking from the channel object)
threshold : float, default = 4
Threshold for peak definition. A peak is defined as a point in the frequency spectrum
that exceeds threshold * rms(real fourier transform)

Returns
-------
freq_peaks : np.ndarray
Frequencies at which a peak was found
"""
freq = np.fft.rfftfreq(len(trace), d=1/fs)
ft = fft.time2freq(trace, fs)

freq_peaks = find_frequency_peaks(freq, ft, fs=fs, threshold=threshold)

return freq_peaks


def find_frequency_peaks(freq: np.ndarray, spectrum : np.ndarray, threshold : float = 4):
"""
Function fo find the frequency peaks in the real fourier transform of the input trace.

Parameters
----------
freq : np.ndarray
Frequencies of a NuRadio time trace
spectrum : np.ndarray
Spectrum of a NuRadio time trace
threshold : float, default = 4
Threshold for peak definition. A peak is defined as a point in the frequency spectrum
that exceeds threshold * rms(real fourier transform)

Returns
-------
freq : np.ndarray
Frequencies at which a peak was found
"""

rms = np.sqrt(np.mean(np.abs(spectrum)**2))
peak_idxs = np.where(np.abs(spectrum) > threshold * rms)[0]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't the RMS of the trace be dominated by the filter shape? I wonder if it makes sense to limit the frequency range over which the RMS is calculated?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With filter shape do you mean the CW peak?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what I mean is if you take a noiseless neutrino or calpulser with no CW, the RMS will mostly come from the fact that we have a high pass and low pass filters so the spectrum has a shape in it.


return freq[peak_idxs]


def get_filter(freq : int, fs, quality_factor=1e3, cache=None):
"""
Function to get single notch filter for a given frequency.

Parameters
----------
freq : np.ndarray
Frequency
fs : float,
sampling frequency in MHz
quality_factor : int, default = 1000
quality factor of the notch filter, defined as the ratio f0/bw, where f0 is the centre frequency
and bw the bandwidth of the filter at (f0,-3 dB)
cache : dict, default = None,
Optional caching dictionary. The function will check whether the frequency to be filtered
is in the dictionary values and will otherwise add it
!!! Note this does not cache the quality factor information

Returns
-------
filter : list, shape (6)
second order IIR notch filter at frequency freq
"""
if cache is not None:
if freq in cache.keys():
return cache[freq]
filter = signal.iirnotch(freq, quality_factor, fs=fs)
if cache is not None:
# Check to avoid cache dictionary overflowing the memory,
# set to roughly stay below 6 MB (every filter is 6 floats + freq = 7 floats ~ 56B)
if len(cache.keys()) < 1e5:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI the .keys() is not necessary but it is more verbose.

cache[freq] = filter
return filter


def filter_cws(trace : np.ndarray, freq : np.ndarray, spectrum : np.ndarray, fs : float, quality_factor=1e3, threshold=4,
cache : dict = None,
filters : list = None):
"""
Function that applies a notch filter at the frequency peaks of a given time trace
using the scipy library

Parameters
----------
trace : np.ndarray
waveform (shape: [2048])
freq : np.ndarray
Frequency of the trace's real fourier transform
spectrum:
the trace's real fourier transform
fs : float
sampling frequency in MHz
quality_factor : int, default = 1000
quality factor of the notch filter, defined as the ratio f0/bw, where f0 is the centre frequency
and bw the bandwidth of the filter at (f0,-3 dB)
threshold : int, default = 4
threshold for peak definition. A peak is defined as a point in the frequency spectrum
that exceeds threshold * rms(real fourier transform)
cache : dict, default = None,
Optional caching dictionary. The function will check whether the frequency to be filtered
is in the dictionary values and will otherwise add it
!!! Note this assumes the quality_factor is the same for all notch filters!!!
filters : NoneType or list, default = None
Optional list to which the filters used in this function can be appended for future reference

Returns
-------
trace : np.ndarray
CW-filtered trace
"""
freqs = find_frequency_peaks(freq, spectrum, threshold=threshold)

if len(freqs):
# the array is reshaped to (nr_of_filters, nr_of_coefficients), since iirnotch is a second order IIR,
# the nr_of_coefficients will be 6: 3 for the numerator and 3 for the denumerator, in that order
notch_filters = np.array([get_filter(freq, fs, quality_factor, cache=cache) for freq in freqs]).reshape(-1, 6)
if filters is not None:
filters.append(notch_filters)
logging.debug(f"Shape of notch filters for one channel is: {notch_filters.shape}")
trace_notched = signal.sosfiltfilt(notch_filters, trace, padtype = None)
return trace_notched
else:
# append empty list when filters is specified to ensure
# filters list is shape 24 when looping over channels
if filters is not None:
filters.append([])

return trace


def plot_trace(channel, ax, fs=3.2e9*units.Hz, label=None, plot_kwargs=dict()):
"""
Function to plot trace of given channel

Parameters
----------
channel : NuRadio channel class
channel from which to get trace
ax : matplotlib.axes
ax on which to plot
fs : float, default = 3.2 Hz
sampling frequency
label : string
plotlabel
plot_kwargs : dict
options for plotting
"""
times = np.arange(2048)/fs / units.ns
trace = channel.get_trace()

legendloc = 2

ax.plot(times, trace, label=label, **plot_kwargs)
ax.set_xlabel("time / ns")
ax.set_ylabel("trace / V")
ax.legend(loc=legendloc)


def plot_ft(channel, ax, label=None, plot_kwargs=dict()):
"""
Function to plot real frequency spectrum of given channel

Parameters
----------
channel : NuRadio channel class
channel from which to get trace
ax : matplotlib.axes
ax on which to plot
label : string
plotlabel
plot_kwargs : dict
options for plotting
"""
freqs = channel.get_frequencies()
spec = channel.get_frequency_spectrum()

legendloc = 2

ax.plot(freqs, np.abs(spec), label=label, **plot_kwargs)
ax.set_xlabel("freq / GHz")
ax.set_ylabel("amplitude / V/GHz")
ax.legend(loc = legendloc)


class channelCWNotchFilter():
""" Continuous wave (CW) filter module. Uses notch filters from the scipy library """

def __init__(self):
pass

def begin(self, quality_factor=1e3, threshold=4, save_filters=False):
self.quality_factor = quality_factor
self.threshold = threshold
self.filters = [] if save_filters else None
# dictionary to cache known notch filters at specific frequencies
self.filter_cache = {}

def run(self, event, station, det):
for channel in station.iter_channels():
fs = channel.get_sampling_rate()
freq = channel.get_frequencies()
spectrum = channel.get_frequency_spectrum()
trace = channel.get_trace()
trace_fil = filter_cws(
trace, freq, spectrum, fs, quality_factor=self.quality_factor, threshold=self.threshold,
cache=self.filter_cache, filters=self.filters)

channel.set_trace(trace_fil, fs)


# Standard test for people playing around with module settings, applies the module as one would in a data reading pipeline
# using one event in RNO_G_DATA (choose station and run) as a test
if __name__ == "__main__":
import os
import logging
import argparse
import matplotlib.pyplot as plt

from NuRadioReco.modules.io.RNO_G.readRNOGDataMattak import readRNOGData

parser = argparse.ArgumentParser(prog="%(prog)s", usage="cw filter test")
parser.add_argument("--station", type=int, default=24)
parser.add_argument("--channel", type = int, default = 0)
parser.add_argument("--run", type=int, default=1)

parser.add_argument("--quality_factor", type=int, default=1e3)
parser.add_argument("--threshold", type=int, default=4)
parser.add_argument("--fs", type=float, default=3.2e9 * units.Hz)

parser.add_argument("--save_dir", type=str, default=None,
help="Directory where to save plot produced by the test.\
If None, saves to NuRadioReco test directory")

args = parser.parse_args()

data_dir = os.environ["RNO_G_DATA"]
rnog_reader = readRNOGData(log_level = logging.DEBUG)

root_dirs = f"{data_dir}/station{args.station}/run{args.run}"
rnog_reader.begin(root_dirs,
# linear voltage calibration
convert_to_voltage=True,
mattak_kwargs=dict(backend="uproot"))

channelCWNotchFilter = channelCWNotchFilter()
channelCWNotchFilter.begin(quality_factor=args.quality_factor, threshold=args.threshold)

for event in rnog_reader.run():
station_id = event.get_station_ids()[0]
station = event.get_station(station_id)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
plot_trace(station.get_channel(args.channel), axs[0], label="before")
plot_ft(station.get_channel(args.channel), axs[1], label="before")
t0 = time.time()
channelCWNotchFilter.run(event, station, det=0)
logger.debug(f"Filter took {time.time() - t0} s to run.")
plot_trace(station.get_channel(args.channel), axs[0], label="after")
plot_ft(station.get_channel(args.channel), axs[1], label="after")

if args.save_dir is None:
fig_dir = os.path.abspath(f"{__file__}/../../test")
else:
fig_dir = args.save_dir


fig.savefig(f"{fig_dir}/test_cw_filter", bbox_inches="tight")
break
1 change: 1 addition & 0 deletions changelog.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ for the phasing_mode == "slice"
- Improved logging: created NuRadioLogger class which is now used by default, the NRR and NRMC loggers
are created automatically when importing the packages, and a new STATUS logging level was added.
- Updating the ray_tracing_2D class in order to be able to use Numba optimization
- Add new cw filter module to NuRadioReco using notch filters on peaks in the frequency spectrum

bugfixes:
- Fixed bug in get_travel_time in directRayTracing propagation module
Expand Down