Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cw filter notch module #724

Merged
merged 17 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion 1

This file was deleted.

78 changes: 40 additions & 38 deletions NuRadioReco/modules/channelCWNotchFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,24 @@
from NuRadioReco.utilities import fft

"""
Contains function to filter continuous wave out of the signal
functions should work on a per event basis to comply with the iteration methods used in readRNOGData
Contains module to filter continuous wave out of the signal using notch filters
on peaks in frequency spectrum
"""


def find_frequency_peaks_from_trace(trace : np.ndarray, fs : float = 3.2e9 * units.Hz, threshold = 4):
def find_frequency_peaks_from_trace(trace : np.ndarray, fs : float, nr_samples : int, threshold=4):
"""
Function fo find the frequency peaks in the real fourier transform of the input trace,

Parameters
----------
trace : np.ndarray
waveform (shape: [24,2048])
fs: float, default = 3.2e9 Hz
sampling frequency of the RNO-G DAQ, (input should be taking from the channel object)
threshold : int, default = 4
fs: float,
sampling frequency , (input should be taking from the channel object)
nr_samples : int
number of samples in a time trace
threshold, default = 4
threshold for peak definition. A peak is defined as a point in the frequency spectrum
that exceeds threshold * rms(real fourier transform)

Expand All @@ -28,24 +30,24 @@ def find_frequency_peaks_from_trace(trace : np.ndarray, fs : float = 3.2e9 * uni
freq_peaks : np.ndarray
frequencies at which a peak was found
"""
freq = np.fft.rfftfreq(2048, d = 1/fs)
freq = np.fft.rfftfreq(nr_samples, d = 1/fs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think len(trace) would be fine here.

ft = fft.time2freq(trace, fs)

freq_peaks = find_frequency_peaks(freq, ft, fs = fs, threshold = threshold)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the pep8 rule includes this line too


return freq_peaks

def find_frequency_peaks(freq: np.ndarray, spectrum : np.ndarray, threshold = 4):
def find_frequency_peaks(freq: np.ndarray, spectrum : np.ndarray, threshold=4):
"""
Function fo find the frequency peaks in the real fourier transform of the input trace,

Parameters
----------
freq : np.ndarray
frequencies of a NuRadio rime trace
frequencies of a NuRadio time trace
spectrum : np.ndarray
spectrum of a NuRadio time trace
threshold : int, default = 4
threshold, default = 4
threshold for peak definition. A peak is defined as a point in the frequency spectrum
that exceeds threshold * rms(real fourier transform)

Expand All @@ -61,7 +63,7 @@ def find_frequency_peaks(freq: np.ndarray, spectrum : np.ndarray, threshold = 4)
return freq[peak_idxs]


def filter_cws(trace : np.ndarray, freq : np.ndarray, spectrum : np.ndarray, fs = 3.2e9 * units.Hz, quality_factor = 1e3, threshold = 4):
def filter_cws(trace : np.ndarray, freq : np.ndarray, spectrum : np.ndarray, fs=3.2e9 * units.Hz, quality_factor=1e3, threshold=4):
"""
Function that applies a notch filter at the frequency peaks of a given time trace
using the scipy library
Expand All @@ -84,9 +86,9 @@ def filter_cws(trace : np.ndarray, freq : np.ndarray, spectrum : np.ndarray, fs
that exceeds threshold * rms(real fourier transform)

"""
freqs = find_frequency_peaks(freq, spectrum, threshold = threshold)
freqs = find_frequency_peaks(freq, spectrum, threshold=threshold)

if len(freqs) !=0:
if len(freqs):
notch_filters = [signal.iirnotch(freq, quality_factor, fs = fs) for freq in freqs]
trace_notched = signal.filtfilt(notch_filters[0][0], notch_filters[0][1], trace)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filtfilt will pad by default, no? (Maybe I'm misremembering)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it does pad, I see two issues:

  1. you get a longer trace than you put in, which is perhaps surprising
  2. you can get weird transients at the former edge of the waveform, where you'll essentially see the step response of the filter

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Cosmin, thank you for the input. I am very much not experienced with signal analysis / filters so I might be asking some obvious questions in what follows.
Filtfit indeed pads the waveform. I looked in the documentation and it applies an "odd padding" (extending the waveform by copying a range of samples at the edge, rotated by 180 degrees). The trace that is returned is the same shape as the input trace so after filter application it cuts the padding. Is it a problem to use padding? In my naive estimation I thought the filter would lead to some artifacts at the edges when not using any padding? Do you mean there would be transients due to the step reponse if the the wavefrom was padded with 0's?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it does indeed crop the padding applied then that's good and we don't have to worry about that part.

And yes, I'm somewhat worried about the step response of the filter at the edge after padding, since you'll be going from 0 to sometimes a non-zero value that sometimes will be large (i.e. > 3 sigma 0.3% of the time), but I guess as long as we never consider anything on the edge of the waveforms it's probably not a problem?

Copy link
Collaborator Author

@CamphynR CamphynR Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now I have set the padding to None (not sure how Scipy then arrives at the same trace length though, but it does), Odd padding is also an option as it basically "extends" the waveform. I do not have a feeling about which option is better.

for notch in notch_filters[1:]:
Expand All @@ -96,7 +98,7 @@ def filter_cws(trace : np.ndarray, freq : np.ndarray, spectrum : np.ndarray, fs
return trace


def plot_trace(channel, ax, fs = 3.2e9 * units.Hz, label = None, plot_kwargs = dict()):
def plot_trace(channel, ax, fs=3.2e9 * units.Hz, label=None, plot_kwargs=dict()):
"""
Function to plot trace of given channel

Expand All @@ -118,13 +120,13 @@ def plot_trace(channel, ax, fs = 3.2e9 * units.Hz, label = None, plot_kwargs = d

legendloc = 2

ax.plot(times, trace, label = label, **plot_kwargs)
ax.plot(times, trace, label=label, **plot_kwargs)
ax.set_xlabel("time / ns")
ax.set_ylabel("trace / V")
ax.legend(loc = legendloc)
ax.legend(loc=legendloc)


def plot_ft(channel, ax, label = None, plot_kwargs = dict()):
def plot_ft(channel, ax, label=None, plot_kwargs=dict()):
"""
Function to plot real frequency spectrum of given channel

Expand All @@ -144,7 +146,7 @@ def plot_ft(channel, ax, label = None, plot_kwargs = dict()):

legendloc = 2

ax.plot(freqs, np.abs(spec), label = label, **plot_kwargs)
ax.plot(freqs, np.abs(spec), label=label, **plot_kwargs)
ax.set_xlabel("freq / GHz")
ax.set_ylabel("amplitude / V/GHz")
ax.legend(loc = legendloc)
Expand All @@ -158,7 +160,7 @@ class channelCWNotchFilter():
def __init__(self):
pass

def begin(self, quality_factor = 1e3, threshold = 4):
def begin(self, quality_factor=1e3, threshold=4):
self.quality_factor = quality_factor
self.threshold = threshold

Expand All @@ -168,7 +170,7 @@ def run(self, event, station, det):
freq = channel.get_frequencies()
spectrum = channel.get_frequency_spectrum()
trace = channel.get_trace()
trace_fil = filter_cws(trace, freq, spectrum, quality_factor = self.quality_factor, threshold = self.threshold, fs = fs)
trace_fil = filter_cws(trace, freq, spectrum, quality_factor=self.quality_factor, threshold=self.threshold, fs=fs)
channel.set_trace(trace_fil, fs)

# Standard test for people playing around with module settings, applies the module as one would in a data reading pipeline
Expand All @@ -181,16 +183,16 @@ def run(self, event, station, det):

from NuRadioReco.modules.io.RNO_G.readRNOGDataMattak import readRNOGData

parser = argparse.ArgumentParser(prog = "%(prog)s", usage = "cw filter test")
parser.add_argument("--station", type = int, default = 24)
parser.add_argument("--run", type = int, default = 1)
parser = argparse.ArgumentParser(prog="%(prog)s", usage="cw filter test")
parser.add_argument("--station", type=int, default=24)
parser.add_argument("--run", type=int, default=1)

parser.add_argument("--quality_factor", type = int, default = 1e3)
parser.add_argument("--threshold", type = int, default = 4)
parser.add_argument("--fs", type = float, default = 3.2e9 * units.Hz)
parser.add_argument("--quality_factor", type=int, default=1e3)
parser.add_argument("--threshold", type=int, default=4)
parser.add_argument("--fs", type=float, default=3.2e9 * units.Hz)

parser.add_argument("--save_dir", type = str, default = None,
help = "Directory where to save plot produced by the test.\
parser.add_argument("--save_dir", type=str, default=None,
help="Directory where to save plot produced by the test.\
If None, saves to NuRadioReco test directory")

args = parser.parse_args()
Expand All @@ -200,28 +202,28 @@ def run(self, event, station, det):

root_dirs = f"{data_dir}/station{args.station}/run{args.run}"
rnog_reader.begin(root_dirs,
convert_to_voltage = True, # linear voltage calibration
mattak_kwargs = dict(backend = "uproot"))
convert_to_voltage=True, # linear voltage calibration
mattak_kwargs=dict(backend="uproot"))

channelCWNotchFilter = channelCWNotchFilter()
channelCWNotchFilter.begin(quality_factor = args.quality_factor, threshold = args.threshold)
channelCWNotchFilter.begin(quality_factor=args.quality_factor, threshold=args.threshold)

for event in rnog_reader.run():
station_id = event.get_station_ids()[0]
station = event.get_station(station_id)

fig, axs = plt.subplots(1, 2, figsize = (14, 6))
plot_trace(station.get_channel(0), axs[0], label = "before")
plot_ft(station.get_channel(0), axs[1], label = "before")
channelCWNotchFilter.run(event, station, det = 0)
plot_trace(station.get_channel(0), axs[0], label = "after")
plot_ft(station.get_channel(0), axs[1], label = "after")
fig, axs = plt.subplots(1, 2, figsize=(14, 6))
plot_trace(station.get_channel(0), axs[0], label="before")
plot_ft(station.get_channel(0), axs[1], label="before")
channelCWNotchFilter.run(event, station, det=0)
plot_trace(station.get_channel(0), axs[0], label="after")
plot_ft(station.get_channel(0), axs[1], label="after")

if args.save_dir is None:
fig_dir = os.path.abspath(f"{__file__}/../../test")
else:
fig_dir = args.save_dir


fig.savefig(f"{fig_dir}/test_cw_filter", bbox_inches = "tight")
fig.savefig(f"{fig_dir}/test_cw_filter", bbox_inches="tight")
break