Skip to content

Commit

Permalink
Final touches
Browse files Browse the repository at this point in the history
  • Loading branch information
gvanhoy committed Sep 5, 2023
1 parent 2fbc788 commit 4964f48
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 54 deletions.
106 changes: 53 additions & 53 deletions tests/test_transforms_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,60 +25,60 @@ def generate_data():


transforms_list = [
# (
# "random_resample_up",
# RandomResample(1.5, num_iq_samples=128, keep_samples=False),
# RandomResample(1.5, num_iq_samples=4096, keep_samples=False),
# ),
# (
# "random_resample_down",
# RandomResample(0.75, num_iq_samples=128, keep_samples=False),
# RandomResample(0.75, num_iq_samples=4096, keep_samples=False),
# ),
(
"random_resample_up",
RandomResample(1.5, num_iq_samples=128, keep_samples=False),
RandomResample(1.5, num_iq_samples=4096, keep_samples=False),
),
(
"random_resample_down",
RandomResample(0.75, num_iq_samples=128, keep_samples=False),
RandomResample(0.75, num_iq_samples=4096, keep_samples=False),
),
("add_noise", AddNoise(-10), AddNoise(-10)),
# ("time_varying_noise", TimeVaryingNoise(-30, -10), TimeVaryingNoise(-30, -10)),
# (
# "rayleigh_fading",
# RayleighFadingChannel(0.05, (1.0, 0.5, 0.1)),
# RayleighFadingChannel(0.05, (1.0, 0.5, 0.1)),
# ),
# ("phase_shift", RandomPhaseShift(0.5), RandomPhaseShift(0.5)),
# ("time_shift", RandomTimeShift(-100.5), RandomTimeShift(-2.5)),
# (
# "time_crop",
# TimeCrop("random", length=128),
# TimeCrop("random", length=4096),
# ),
# ("time_reversal", TimeReversal(False), TimeReversal(False)),
# ("frequency_shift", RandomFrequencyShift(-0.25), RandomFrequencyShift(-0.25)),
# (
# "delayed_frequency_shift",
# RandomDelayedFrequencyShift(0.2, 0.25),
# RandomDelayedFrequencyShift(0.2, 0.25),
# ),
# (
# "oscillator_drift",
# LocalOscillatorDrift(0.01, 0.001),
# LocalOscillatorDrift(0.01, 0.001),
# ),
# ("gain_drift", GainDrift(0.01, 0.001, 0.1), GainDrift(0.01, 0.001, 0.1)),
# (
# "iq_imbalance",
# IQImbalance(3, np.pi / 180, 0.05),
# IQImbalance(3, np.pi / 180, 0.05),
# ),
# ("roll_off", RollOff(0.05, 0.98), RollOff(0.05, 0.98)),
# ("add_slope", AddSlope(), AddSlope()),
# ("spectral_inversion", SpectralInversion(), SpectralInversion()),
# ("channel_swap", ChannelSwap(), ChannelSwap()),
# ("magnitude_rescale", RandomMagRescale(0.5, 3), RandomMagRescale(0.5, 3)),
# (
# "drop_samples",
# RandomDropSamples(0.01, 50, ["zero"]),
# RandomDropSamples(0.01, 50, ["zero"]),
# ),
# ("quantize", Quantize(32, ["floor"]), Quantize(32, ["floor"])),
# ("clip", Clip(0.85), Clip(0.85)),
("time_varying_noise", TimeVaryingNoise(-30, -10), TimeVaryingNoise(-30, -10)),
(
"rayleigh_fading",
RayleighFadingChannel(0.05, (1.0, 0.5, 0.1)),
RayleighFadingChannel(0.05, (1.0, 0.5, 0.1)),
),
("phase_shift", RandomPhaseShift(0.5), RandomPhaseShift(0.5)),
("time_shift", RandomTimeShift(-100.5), RandomTimeShift(-2.5)),
(
"time_crop",
TimeCrop("random", length=128),
TimeCrop("random", length=4096),
),
("time_reversal", TimeReversal(False), TimeReversal(False)),
("frequency_shift", RandomFrequencyShift(-0.25), RandomFrequencyShift(-0.25)),
(
"delayed_frequency_shift",
RandomDelayedFrequencyShift(0.2, 0.25),
RandomDelayedFrequencyShift(0.2, 0.25),
),
(
"oscillator_drift",
LocalOscillatorDrift(0.01, 0.001),
LocalOscillatorDrift(0.01, 0.001),
),
("gain_drift", GainDrift(0.01, 0.001, 0.1), GainDrift(0.01, 0.001, 0.1)),
(
"iq_imbalance",
IQImbalance(3, np.pi / 180, 0.05),
IQImbalance(3, np.pi / 180, 0.05),
),
("roll_off", RollOff(0.05, 0.98), RollOff(0.05, 0.98)),
("add_slope", AddSlope(), AddSlope()),
("spectral_inversion", SpectralInversion(), SpectralInversion()),
("channel_swap", ChannelSwap(), ChannelSwap()),
("magnitude_rescale", RandomMagRescale(0.5, 3), RandomMagRescale(0.5, 3)),
(
"drop_samples",
RandomDropSamples(0.01, 50, ["zero"]),
RandomDropSamples(0.01, 50, ["zero"]),
),
("quantize", Quantize(32, ["floor"]), Quantize(32, ["floor"])),
("clip", Clip(0.85), Clip(0.85)),
]


Expand Down
2 changes: 1 addition & 1 deletion torchsig/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2969,7 +2969,7 @@ class Quantize(SignalTransform):
def __init__(
self,
num_levels: IntParameter = UniformDiscreteRD(
np.asarray([16, 24, 32, 40, 48, 56, 64])
np.asarray([16, 24, 32, 40, 48, 56, 64], dtype=int)
),
round_type: List[str] = (["floor", "middle", "ceiling"]),
) -> None:
Expand Down
6 changes: 6 additions & 0 deletions torchsig/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def to_distribution(dist):
if isinstance(dist, tuple):
return UniformContinuousRD(dist[0], dist[1])

if isinstance(dist, list):
return UniformDiscreteRD(dist)

def __call__(self, num: int = 1):
raise NotImplementedError

Expand All @@ -26,6 +29,9 @@ def __init__(self, constant: float) -> None:
self.constant = constant

def __call__(self, num: int = 1):
if num == 1:
return self.constant

return np.repeat(self.constant, repeats=num)


Expand Down

0 comments on commit 4964f48

Please sign in to comment.