Refactored to RandomDistribution
gvanhoy committed Sep 5, 2023
1 parent 4aaa6e0 commit 2fbc788
Showing 9 changed files with 293 additions and 553 deletions.
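Note: the common thread in the diffs below is that the free function `to_distribution` (imported from `torchsig.transforms.functional` and seeded through an explicit `random_generator` argument) is replaced everywhere by the `RandomDistribution.to_distribution` classmethod imported from `torchsig.utils.types`, with the `random_generator` argument dropped at each call site. As a rough sketch of the convention such a helper follows in torchsig — a scalar maps to a constant, a `(low, high)` tuple to a uniform range, and a list to a discrete choice — the re-implementation below is illustrative only and may differ from the real `RandomDistribution`:

```python
import numpy as np

# Illustrative re-implementation of torchsig's to_distribution convention.
# The actual RandomDistribution class in torchsig.utils.types may differ.
def to_distribution(param, rng=np.random):
    if callable(param):
        return param                                # already a sampler
    if isinstance(param, tuple):                    # (low, high) range
        return lambda size=None: rng.uniform(param[0], param[1], size)
    if isinstance(param, (list, np.ndarray)):       # discrete choices
        return lambda size=None: rng.choice(param, size)
    return lambda size=None: param                  # constant value

snrs_db = to_distribution((0.0, 10.0))  # uniform SNR draw between 0 and 10 dB
print(snrs_db())
```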
3 changes: 0 additions & 3 deletions .gitignore
@@ -13,7 +13,4 @@ lightning_logs/
*.benchmarks/
dist/
examples/*.ipynb_checkpoints/
-<<<<<<< HEAD
docs/bin
-=======
->>>>>>> 53cb06343cb89ff1e764c6813fe1d71d981cae0f
3 changes: 0 additions & 3 deletions pyproject.toml
@@ -37,12 +37,9 @@ dependencies = [
    "sympy",
    "numba",
    "click",
-<<<<<<< HEAD
    "nbsphinx",
    "pypandoc",
    "opencv-contrib-python-headless",
-=======
->>>>>>> 53cb06343cb89ff1e764c6813fe1d71d981cae0f
]
dynamic = ["version"]

4 changes: 0 additions & 4 deletions scripts/generate_wideband_sig53.py
@@ -45,11 +45,7 @@ def generate(root: str, configs: List[conf.WidebandSig53Config]):
    "--all", default=True, help="Generate all versions of wideband_sig53 dataset."
)
@click.option(
-<<<<<<< HEAD
    "--qa", default=False, help="Generate only QA versions of wideband_sig53 dataset."
-=======
-    "--qa", default=True, help="Generate only QA versions of wideband_sig53 dataset."
->>>>>>> 53cb06343cb89ff1e764c6813fe1d71d981cae0f
)
@click.option(
    "--impaired",
4 changes: 0 additions & 4 deletions torchsig/__init__.py
@@ -1,5 +1 @@
-<<<<<<< HEAD
__version__ = "0.4.2"
-=======
-__version__ = "0.4.1"
->>>>>>> 53cb06343cb89ff1e764c6813fe1d71d981cae0f
95 changes: 59 additions & 36 deletions torchsig/datasets/file_datasets.py
@@ -1,21 +1,9 @@
+import json
+import os
import xml
import xml.etree.ElementTree as ET
+from torchsig.datasets.wideband import BurstSourceDataset, SignalBurst
from typing import Any, List, Optional

import numpy as np
import pandas as pd

-from torchsig.datasets.wideband import BurstSourceDataset, SignalBurst
-from torchsig.transforms.functional import (
-    FloatParameter,
-    NumericParameter,
-    to_distribution,
-    uniform_continuous_distribution,
-    uniform_discrete_distribution,
-)
-from torchsig.utils.types import SignalDescription
-import json
-import os

class WidebandFileSignalBurst(SignalBurst):
@@ -70,7 +58,9 @@ def generate_iq(self):
            # Read desired number of samples from file
            iq_data = (
                np.frombuffer(
-                    file_object.read(int(self.num_iq_samples) * self.bytes_per_sample),
+                    file_object.read(
+                        int(self.num_iq_samples) * self.bytes_per_sample
+                    ),
                    dtype=self.capture_type,
                )
                .astype(np.float64)
@@ -82,8 +72,8 @@ def generate_iq(self):
            # file repetitively and summing with itself
            iq_data = np.zeros(self.num_iq_samples, dtype=np.complex128)
        return iq_data[: self.num_iq_samples]


class TargetInterpreter:
    """The TargetInterpreter base class is meant to be inherited and modified
    for specific interpreters such that each sub-class implements a transform
@@ -168,13 +158,16 @@ def convert_to_signalburst(
        for label in self.detections_df.iloc[df_indicies].itertuples():
            # Determine cut vs full capture relationship
            startInWindow = bool(
-                label.start >= start_sample and label.start < start_sample + self.num_iq_samples
+                label.start >= start_sample
+                and label.start < start_sample + self.num_iq_samples
            )
            stopInWindow = bool(
-                label.stop > start_sample and label.stop <= start_sample + self.num_iq_samples
+                label.stop > start_sample
+                and label.stop <= start_sample + self.num_iq_samples
            )
            spansFullWindow = bool(
-                label.start <= start_sample and label.stop >= start_sample + self.num_iq_samples
+                label.start <= start_sample
+                and label.stop >= start_sample + self.num_iq_samples
            )
            fullyContainedInWindow = bool(startInWindow and stopInWindow)

@@ -337,7 +330,9 @@ def __init__(
        self.class_column = class_column
        # Generate dataframe
        self.detections_df = self._convert_to_dataframe()
-        self.detections_df = self.detections_df.sort_values(by=["start"]).reset_index(drop=True)
+        self.detections_df = self.detections_df.sort_values(by=["start"]).reset_index(
+            drop=True
+        )
        self.num_labels = len(self.detections_df)
        self.detections_df = self._convert_class_name_to_index()

@@ -400,7 +395,9 @@ def __init__(
        self.class_target = class_target
        # Generate dataframe
        self.detections_df = self._convert_to_dataframe()
-        self.detections_df = self.detections_df.sort_values(by=["start"]).reset_index(drop=True)
+        self.detections_df = self.detections_df.sort_values(by=["start"]).reset_index(
+            drop=True
+        )
        self.num_labels = len(self.detections_df)
        self.detections_df = self._convert_class_name_to_index()

@@ -561,7 +558,10 @@ def __init__(

        # Distribute randomness evenly over labels, rather than files then labels
        # If more than 10,000 files, omit this step for speed
-        if self.sample_policy == "random_labels" and len(self.target_files) < 10_000:
+        if (
+            self.sample_policy == "random_labels"
+            and len(self.target_files) < 10_000
+        ):
            annotations_per_file = []
            for file_index, target_file in enumerate(self.target_files):
                # Read total file size
@@ -581,11 +581,14 @@
                # Track number of annotations
                annotations_per_file.append(len(annotations))
            total_annotations = sum(annotations_per_file)
-            self.file_probabilities = np.asarray(annotations_per_file) / total_annotations
+            self.file_probabilities = (
+                np.asarray(annotations_per_file) / total_annotations
+            )

        # Generate the index by creating a set of bursts.
        self.index = [
-            (collection, idx) for idx, collection in enumerate(self._generate_burst_collections())
+            (collection, idx)
+            for idx, collection in enumerate(self._generate_burst_collections())
        ]

    def _generate_burst_collections(self) -> List[List[SignalBurst]]:
@@ -639,7 +642,9 @@ def _generate_burst_collections(self) -> List[List[SignalBurst]]:
                        self.capture_type.itemsize * 2
                    )
                else:
-                    sample_burst_collection[0].bytes_per_sample = self.capture_type.itemsize
+                    sample_burst_collection[
+                        0
+                    ].bytes_per_sample = self.capture_type.itemsize
            else:
                # Create invalid SignalBurst for data file information only
                sample_burst_collection = []
@@ -703,22 +708,32 @@ def _generate_burst_collections(self) -> List[List[SignalBurst]]:
                while null_interval < self.num_iq_samples:
                    # Randomly sample label index to search around
                    label_index = np.random.randint(interpreter.num_labels)
-                    if interpreter.num_labels > 1 and label_index + 1 <= interpreter.num_labels - 1:
+                    if (
+                        interpreter.num_labels > 1
+                        and label_index + 1 <= interpreter.num_labels - 1
+                    ):
                        # Max over previous annotation stop and previous null start to handle cases of long signals
-                        null_start_index = max(annotations.iloc[label_index].stop, null_start_index)
+                        null_start_index = max(
+                            annotations.iloc[label_index].stop, null_start_index
+                        )
                        null_stop_index = annotations.iloc[label_index + 1].start
                    elif (
-                        interpreter.num_labels > 1 and label_index + 1 > interpreter.num_labels - 1
+                        interpreter.num_labels > 1
+                        and label_index + 1 > interpreter.num_labels - 1
                    ):
                        # Start start index at end of final label
-                        null_start_index = max(annotations.iloc[label_index].stop, null_start_index)
+                        null_start_index = max(
+                            annotations.iloc[label_index].stop, null_start_index
+                        )
                        null_stop_index = capture_duration_samples
                    elif interpreter.num_labels == 1:
                        # Sample from before or after the only label
                        before = True if np.random.rand() >= 0.5 else False
                        null_start_index = 0 if before else annotations.iloc[0].stop
                        null_stop_index = (
-                            annotations.iloc[0].start if before else capture_duration_samples
+                            annotations.iloc[0].start
+                            if before
+                            else capture_duration_samples
                        )
                    else:
                        # Sample from anywhere in file
@@ -771,7 +786,9 @@ def _generate_burst_collections(self) -> List[List[SignalBurst]]:
        for sample_idx in range(self.num_valid_samples):
            if self.sample_policy == "random_labels":
                # Sample random file, weighted by number of annotations
-                file_index = np.random.choice(len(self.data_files), p=self.file_probabilities)
+                file_index = np.random.choice(
+                    len(self.data_files), p=self.file_probabilities
+                )
                # Read total file size
                capture_duration_samples = (
                    os.path.getsize(os.path.join(self.data_files[file_index]))
Expand Down Expand Up @@ -810,11 +827,15 @@ def _generate_burst_collections(self) -> List[List[SignalBurst]]:
latest_sample_index = burst_start_index + burst_duration / 2
else:
# Long burst: Ensure at least a quarter of the window is occupied
earliest_sample_index = burst_start_index - (0.75 * self.num_iq_samples)
earliest_sample_index = burst_start_index - (
0.75 * self.num_iq_samples
)
latest_sample_index = annotations.iloc[label_index].stop - (
0.25 * self.num_iq_samples
)
data_index = max(0, np.random.randint(earliest_sample_index, latest_sample_index))
data_index = max(
0, np.random.randint(earliest_sample_index, latest_sample_index)
)

# Check duration
if capture_duration_samples - data_index < self.num_iq_samples:
@@ -838,7 +859,9 @@ def _generate_burst_collections(self) -> List[List[SignalBurst]]:
                        self.capture_type.itemsize * 2
                    )
                else:
-                    sample_burst_collection[0].bytes_per_sample = self.capture_type.itemsize
+                    sample_burst_collection[
+                        0
+                    ].bytes_per_sample = self.capture_type.itemsize

            # If sequentially sampling, increment
            if self.sample_policy == "sequential_labels":
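One detail worth calling out from the file_datasets.py hunks above: under the "random_labels" sample policy, files are drawn with probability proportional to their annotation counts, which makes every individual label equally likely regardless of how labels are spread across files. A self-contained sketch of that weighting (the counts are hypothetical):

```python
import numpy as np

# Hypothetical annotation counts for three capture files.
annotations_per_file = [12, 3, 45]
total_annotations = sum(annotations_per_file)

# Weight each file by its share of the labels, as in the diff above.
file_probabilities = np.asarray(annotations_per_file) / total_annotations

# Drawing file indices now lands on each label uniformly: file 2 holds
# 45/60 of the labels, so it is chosen 75% of the time.
file_index = np.random.choice(len(annotations_per_file), p=file_probabilities)
print(file_index, file_probabilities)
```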
51 changes: 17 additions & 34 deletions torchsig/datasets/wideband.py
@@ -34,13 +34,10 @@
from torchsig.transforms.functional import (
    FloatParameter,
    NumericParameter,
-    to_distribution,
-    uniform_continuous_distribution,
-    uniform_discrete_distribution,
)
from torchsig.utils.dataset import SignalDataset
from torchsig.utils.dsp import low_pass
-from torchsig.utils.types import SignalData, SignalDescription
+from torchsig.utils.types import SignalData, SignalDescription, RandomDistribution


class SignalBurst(SignalDescription):
@@ -226,9 +223,8 @@ def __init__(
            modulation = self.class_list
        else:
            modulation = [modulation] if isinstance(modulation, str) else modulation
-        self.classes = to_distribution(
+        self.classes = RandomDistribution.to_distribution(
            modulation,
-            random_generator=self.random_generator,
        )

        # Update freq values
@@ -526,9 +522,8 @@ def __init__(
        **kwargs,
    ):
        super(FileSignalBurst, self).__init__(**kwargs)
-        self.file_path = to_distribution(
+        self.file_path = RandomDistribution.to_distribution(
            file_path,
-            random_generator=self.random_generator,
        )
        self.file_reader = file_reader
        self.class_list = class_list
@@ -674,20 +669,12 @@ def __init__(
        self.num_iq_samples = num_iq_samples
        self.num_samples = num_samples
        self.burst_class = burst_class
-        self.bandwidths = to_distribution(
-            bandwidths, random_generator=self.random_generator
-        )
-        self.center_frequencies = to_distribution(
-            center_frequencies, random_generator=self.random_generator
-        )
-        self.burst_durations = to_distribution(
-            burst_durations, random_generator=self.random_generator
-        )
-        self.silence_durations = to_distribution(
-            silence_durations, random_generator=self.random_generator
-        )
-        self.snrs_db = to_distribution(snrs_db, random_generator=self.random_generator)
-        self.start = to_distribution(start, random_generator=self.random_generator)
+        self.bandwidths = RandomDistribution.to_distribution(bandwidths)
+        self.center_frequencies = RandomDistribution.to_distribution(center_frequencies)
+        self.burst_durations = RandomDistribution.to_distribution(burst_durations)
+        self.silence_durations = RandomDistribution.to_distribution(silence_durations)
+        self.snrs_db = RandomDistribution.to_distribution(snrs_db)
+        self.start = RandomDistribution.to_distribution(start)

        # Generate the index by creating a set of bursts.
        self.index = [
@@ -720,7 +707,6 @@ def _generate_burst_collections(self) -> List[List[SignalBurst]]:
                        center_frequency=center_frequency,
                        bandwidth=bandwidth,
                        snr=snr,
-                        random_generator=self.random_generator,
                    )
                )
            start = start + burst_duration + silence_duration
@@ -1024,10 +1010,8 @@ def __init__(
        )
        self.target_transform = target_transform

-        self.num_signals = to_distribution(
-            num_signals, random_generator=self.random_generator
-        )
-        self.snrs = to_distribution(snrs, random_generator=self.random_generator)
+        self.num_signals = RandomDistribution.to_distribution(num_signals)
+        self.snrs = RandomDistribution.to_distribution(snrs)

    def __gen_metadata__(self, modulation_list: List) -> pd.DataFrame:
        """This method defines the parameters of the modulations to be inserted
@@ -1135,13 +1119,11 @@ def __getitem__(self, item: int) -> Tuple[np.ndarray, Any]:
            ):
                # Signal is bursty
                bursty = True
-                burst_duration = to_distribution(
+                burst_duration = RandomDistribution.to_distribution(
                    literal_eval(self.metadata.iloc[meta_idx].burst_duration),
-                    random_generator=self.random_generator,
                )()
-                silence_multiple = to_distribution(
+                silence_multiple = RandomDistribution.to_distribution(
                    literal_eval(self.metadata.iloc[meta_idx].silence_multiple),
-                    random_generator=self.random_generator,
                )()
                stops_in_frame = False
                if hop_random_var < self.metadata.iloc[meta_idx].freq_hopping_prob:
@@ -1152,11 +1134,10 @@ def __getitem__(self, item: int) -> Tuple[np.ndarray, Any]:
                    bandwidth = self.random_generator.uniform(0.025, 0.05)

                    silence_duration = burst_duration * (silence_multiple - 1)
-                    freq_channels = to_distribution(
+                    freq_channels = RandomDistribution.to_distribution(
                        literal_eval(
                            self.metadata.iloc[meta_idx].freq_hopping_channels
                        ),
-                        random_generator=self.random_generator,
                    )()

                    # Convert channel count to list of center frequencies
@@ -1486,7 +1467,9 @@ def __call__(self, data: Any) -> Any:
            (x + bandwidth / 2, y - bandwidth / 2) for x, y in unoccupied_bands
        ]
        rand_band_idx = np.random.randint(len(center_freqs))
-        center_freqs_dist = to_distribution(center_freqs[rand_band_idx])
+        center_freqs_dist = RandomDistribution.to_distribution(
+            center_freqs[rand_band_idx]
+        )
        center_freq = center_freqs_dist()
        bursty = True if np.random.rand() < 0.5 else False
        burst_duration = np.random.uniform(0.05, 1.0) if bursty else 1.0
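The last wideband.py hunk picks a center frequency for a new signal from the unoccupied portions of the band: each gap is shrunk by half the signal bandwidth so the signal fits, one gap is chosen at random, and a frequency is drawn uniformly within it (assuming, as above, that `to_distribution` on a `(low, high)` tuple yields a uniform draw). A standalone sketch with invented band edges:

```python
import numpy as np

bandwidth = 0.1                                  # normalized signal bandwidth
unoccupied_bands = [(-0.4, -0.2), (0.1, 0.45)]   # hypothetical spectral gaps

# Shrink each gap by half the bandwidth on both sides so the whole
# signal stays inside the gap, mirroring the diff above.
center_freqs = [
    (lo + bandwidth / 2, hi - bandwidth / 2) for lo, hi in unoccupied_bands
]
rand_band_idx = np.random.randint(len(center_freqs))
lo, hi = center_freqs[rand_band_idx]
center_freq = np.random.uniform(lo, hi)  # stand-in for the distribution call
print(center_freq)
```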