
Removed unused print statements; added a collate_fn in generate_sig53.py to prevent the torch DataLoader from converting the labels to tensors.

	modified:   scripts/generate_sig53.py
	modified:   scripts/generate_wideband_sig53.py
	modified:   scripts/train_sig53.py
	modified:   torchsig/datasets/modulations.py
	modified:   torchsig/datasets/sig53.py
	modified:   torchsig/datasets/synthetic.py
	modified:   torchsig/datasets/wideband.py
	modified:   torchsig/datasets/wideband_sig53.py
	modified:   torchsig/transforms/target_transforms.py
	modified:   torchsig/transforms/transforms.py
	modified:   torchsig/utils/visualize.py
	modified:   torchsig/utils/writer.py
pvallance committed Jul 13, 2024
1 parent 691302c commit cce5329
Showing 12 changed files with 101 additions and 94 deletions.
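
The central change is the collate_fn added to the generation scripts. Below is a minimal sketch of the behavior it works around, assuming standard torch.utils.data semantics; the toy dataset is hypothetical, not torchsig code:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    # Hypothetical stand-in for a torchsig dataset: each item is an
    # (iq_data, label) pair whose label is a plain tuple of mixed types.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(16), ("bpsk", 0.1, 0.9)  # (class name, start, stop)

def collate_fn(batch):
    # Transpose a list of (data, label) pairs into (data_batch, label_batch).
    # This bypasses torch's default_collate, which would recursively convert
    # the numeric label fields into tensors.
    return tuple(zip(*batch))

loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=collate_fn)
data, labels = next(iter(loader))
print(type(labels[0]))  # <class 'tuple'>: labels survive unconverted

With the default collate, the numeric label fields would come back as batched tensors, which is exactly what the dataset writer here needs to avoid.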
22 changes: 7 additions & 15 deletions scripts/generate_sig53.py
@@ -6,13 +6,16 @@
import os
import numpy as np

def collate_fn(batch):
return tuple(zip(*batch))


def generate(path: str, configs: List[conf.Sig53Config], num_workers: int, num_samples_override: int):
num_samples = config.num_samples if num_samples == 0 else num_samples
num_samples = config.num_samples if num_samples_override == 0 else num_samples_override
for config in configs:
num_samples = config.num_samples if num_samples_override <=0 else num_samples_override
batch_size = int(np.min((config.num_samples // num_workers, 32)))
print(f'batch_size -> {batch_size} num_samples -> {num_samples}')
print(f'batch_size -> {batch_size} num_samples -> {num_samples}, config -> {config}')
ds = ModulationsDataset(
level=config.level,
num_samples=num_samples,
@@ -21,19 +24,8 @@ def generate(path: str, configs: List[conf.Sig53Config], num_workers: int, num_s
include_snr=config.include_snr,
eb_no=config.eb_no,
)
loader = DatasetLoader(
ds,
seed=12345678,
num_workers=num_workers,
batch_size=batch_size,
)
creator = DatasetCreator(
ds,
seed=12345678,
path="{}".format(os.path.join(path, config.name)),
loader=loader,
num_workers=num_workers,
)
dataset_loader = DatasetLoader(ds, seed=12345678, collate_fn=collate_fn, num_workers=num_workers, batch_size=batch_size)
creator = DatasetCreator(ds, seed=12345678, path="{}".format(os.path.join(path, config.name)), loader=dataset_loader, num_workers=num_workers)
creator.create()


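A quick worked example of the batch-size heuristic used above (a sketch; the sample counts are illustrative, not actual config values):

import numpy as np

def pick_batch_size(num_samples: int, num_workers: int, cap: int = 32) -> int:
    # Cap batches at 32, but shrink them when the dataset is so small that
    # num_workers full batches would overshoot the total sample count.
    return int(np.min((num_samples // num_workers, cap)))

print(pick_batch_size(1_000_000, 16))  # 32: large dataset, capped
print(pick_batch_size(100, 16))        # 6: tiny QA-sized run
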
103 changes: 76 additions & 27 deletions scripts/generate_wideband_sig53.py
@@ -2,51 +2,99 @@
from torchsig.utils.writer import DatasetCreator, DatasetLoader
from torchsig.datasets.wideband import WidebandModulationsDataset
from torchsig.datasets import conf
from torchsig.transforms.transforms import *
from typing import List
import click
import os
import numpy as np

def collate_fn(batch):
return tuple(zip(*batch))


def generate(root: str, configs: List[conf.WidebandSig53Config], num_workers: int):
modulation_list = ["ook",
"bpsk",
"4pam",
"4ask",
"qpsk",
"8pam",
"8ask",
"8psk",
"16qam",
"16pam",
"16ask",
"16psk",
"32qam",
"32qam_cross",
"32pam",
"32ask",
"32psk",
"64qam",
"64pam",
"64ask",
"64psk",
"128qam_cross",
"256qam",
"512qam_cross",
"1024qam",
"2fsk",
"2gfsk",
"2msk",
"2gmsk",
"4fsk",
"4gfsk",
"4msk",
"4gmsk",
"8fsk",
"8gfsk",
"8msk",
"8gmsk",
"16fsk",
"16gfsk",
"16msk",
"16gmsk",
"ofdm-64",
"ofdm-72",
"ofdm-128",
"ofdm-180",
"ofdm-256",
"ofdm-300",
"ofdm-512",
"ofdm-600",
"ofdm-900",
"ofdm-1024",
"ofdm-1200",
"ofdm-2048",
]

def generate(root: str, configs: List[conf.WidebandSig53Config], num_workers: int, num_samples_override: int):
for config in configs:
batch_size = int(np.min((config.num_samples // num_workers, 32)))
num_samples = config.num_samples if num_samples_override <= 0 else num_samples_override
batch_size = int(np.min((num_samples // num_workers, 32)))
print(f'batch_size -> {batch_size} num_samples -> {num_samples}, config -> {config}')
wideband_ds = WidebandModulationsDataset(
level=config.level,
num_iq_samples=config.num_iq_samples,
num_samples=config.num_samples,
num_samples=num_samples,
modulation_list=modulation_list,
target_transform=DescToListTuple(),
seed=config.seed,
)

dataset_loader = DatasetLoader(
wideband_ds,
seed=12345678,
collate_fn=collate_fn,
num_workers=num_workers,
batch_size=batch_size,
)
creator = DatasetCreator(
wideband_ds,
seed=12345678,
path=os.path.join(root, config.name),
loader=dataset_loader,
num_workers=num_workers,
)
dataset_loader = DatasetLoader(wideband_ds, seed=12345678, collate_fn=collate_fn, num_workers=num_workers, batch_size=batch_size)
creator = DatasetCreator(wideband_ds, seed=12345678, num_workers=num_workers, path=os.path.join(root, config.name), loader=dataset_loader,)
creator.create()


@click.command()
@click.option("--root", default="wideband_sig53", help="Path to generate wideband_sig53 datasets")
@click.option("--all", default=True, help="Generate all versions of wideband_sig53 dataset.")
@click.option("--qa", default=False, help="Generate only QA versions of wideband_sig53 dataset.")
@click.option("--num-samples", default=-1, help="Override for number of dataset samples.")
@click.option("--impaired", default=False, help="Generate impaired dataset. Ignored if --all=True (default)",)
@click.option("--num-workers", "num_workers", default=os.cpu_count() // 2, help="Define number of workers for both DatasetLoader and DatasetCreator")
@click.option("--impaired", default=False, help="Generate impaired dataset. Ignored if --all=True (default)")
def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int):
def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int, num_samples: int):
if not os.path.isdir(root):
os.mkdir(root)

@@ -60,20 +108,21 @@ def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int):
conf.WidebandSig53ImpairedTrainQAConfig,
conf.WidebandSig53ImpairedValQAConfig,
]

if all:
generate(root, configs[:4], num_workers)
generate(root, configs, num_workers, num_samples)
return

if qa:
generate(root, configs[4:], num_workers)
elif qa:
generate(root, configs[4:], num_workers, num_samples)
return

if impaired:
generate(root, configs[2:4], num_workers)
elif impaired:
generate(root, configs[2:], num_workers, num_samples)
return

generate(root, configs[:2], num_workers)
else:
generate(root, configs[:2], num_workers, num_samples)


if __name__ == "__main__":
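The --num-samples plumbing follows one pattern in both generation scripts: any non-positive override (the click default is -1) falls back to the per-config sample count. A small sketch, assuming only the num_samples attribute shown above (the 250_000 is illustrative):

def resolve_num_samples(config_num_samples: int, num_samples_override: int) -> int:
    # The click default of -1 means "use the config"; a positive value wins.
    return config_num_samples if num_samples_override <= 0 else num_samples_override

assert resolve_num_samples(250_000, -1) == 250_000  # default behavior
assert resolve_num_samples(250_000, 100) == 100     # quick smoke-test run
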
17 changes: 4 additions & 13 deletions scripts/train_sig53.py
@@ -11,7 +11,6 @@
from sklearn.metrics import classification_report
from torchsig.utils.cm_plotter import plot_confusion_matrix
from torchsig.datasets.sig53 import Sig53
from torchsig.datasets import conf
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torch import optim
@@ -80,11 +79,6 @@ def main(root: str, impaired: bool)
)
target_transform = DescToClassIndex(class_list=class_list)

if impaired == True:
num_samples = conf.Sig53ImpairedTrainConfig.num_samples
else:
num_samples = conf.Sig53CleanTrainConfig.num_samples

sig53_train = Sig53(
root,
train=True,
@@ -103,21 +97,18 @@
use_signal_data=True,
)

num_workers=os.cpu_count() // 2
batch_size=int(np.min((num_samples // num_workers, 32)))
print(batch_size,'batch_size',num_workers,'num_workers')
# Create dataloaders
train_dataloader = DataLoader(
dataset=sig53_train,
batch_size=batch_size,
num_workers=num_workers,
batch_size=os.cpu_count(),
num_workers=os.cpu_count() // 2,
shuffle=True,
drop_last=True,
)
val_dataloader = DataLoader(
dataset=sig53_val,
batch_size=batch_size,
num_workers=num_workers,
batch_size=os.cpu_count(),
num_workers=os.cpu_count() // 2,
shuffle=False,
drop_last=True,
)
6 changes: 4 additions & 2 deletions torchsig/datasets/modulations.py
@@ -260,14 +260,16 @@ def __init__(
)

if num_digital > 0 and num_ofdm > 0:
super(ModulationsDataset, self).__init__([digital_dataset, ofdm_dataset], **kwargs
)
super(ModulationsDataset, self).__init__([digital_dataset, ofdm_dataset], **kwargs)
# Torch's ConcatDataset should create this.

elif num_digital > 0:
super(ModulationsDataset, self).__init__([digital_dataset], **kwargs)
elif num_ofdm > 0:
super(ModulationsDataset, self).__init__([ofdm_dataset], **kwargs)
else:
raise ValueError("Input classes must contain at least 1 valid class")


def __getitem__(self, item):
return super(ModulationsDataset, self).__getitem__(item)
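
The new comment points at the concatenation pattern behind ModulationsDataset. A sketch of the idea with torch's own ConcatDataset (the torchsig base class may differ in details; the tensor datasets are stand-ins):

import torch
from torch.utils.data import ConcatDataset, TensorDataset

digital_ds = TensorDataset(torch.zeros(4, 2))  # stand-in for digital_dataset
ofdm_ds = TensorDataset(torch.ones(6, 2))      # stand-in for ofdm_dataset

combined = ConcatDataset([digital_ds, ofdm_ds])
print(len(combined))  # 10: lengths add
print(combined[5])    # indices past the first dataset fall through to the second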
4 changes: 1 addition & 3 deletions torchsig/datasets/sig53.py
@@ -84,9 +84,7 @@ def __init__(
cfg = getattr(conf, cfg)() # type: ignore

self.path = self.root / cfg.name
self.env = lmdb.Environment(
str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False
)
self.env = lmdb.Environment(str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False)
self.data_db = self.env.open_db(b"data")
self.label_db = self.env.open_db(b"label")
with self.env.begin(db=self.data_db) as data_txn:
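The LMDB layout used by Sig53, as a standalone read-side sketch. The environment arguments mirror the constructor above; the path is illustrative, and the pickled integer key is an assumption (the key encoding is not visible in this hunk):

import pickle
import lmdb

env = lmdb.Environment("sig53_clean_train", map_size=int(1e12), max_dbs=2, lock=False)
data_db = env.open_db(b"data")
label_db = env.open_db(b"label")

with env.begin(db=data_db) as txn:
    num_samples = txn.stat()["entries"]  # one entry per sample

with env.begin(db=label_db) as txn:
    label = pickle.loads(txn.get(pickle.dumps(0)))  # assumed key encoding
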
2 changes: 0 additions & 2 deletions torchsig/datasets/synthetic.py
@@ -211,7 +211,6 @@ def __init__(
)
super(DigitalModulationDataset, self).__init__([const_dataset, fsk_dataset, gfsks_dataset])


class SyntheticDataset(SignalDataset):
def __init__(self, **kwargs) -> None:
super(SyntheticDataset, self).__init__(**kwargs)
@@ -921,7 +920,6 @@ def _generate_samples(self, item: Tuple) -> np.ndarray:
# scale the frequency map by the oversampling rate such that the tones
# are packed tighter around f=0 the larger the oversampling rate
const_oversampled = const / oversampling_rate
# print(f'const -> {const}, const_oversampled -> {const_oversampled} samples_per_symbol_recalculated -> {samples_per_symbol_recalculated}')

orig_state = np.random.get_state()
if not self.random_data:
Expand Down
4 changes: 0 additions & 4 deletions torchsig/datasets/wideband.py
@@ -633,7 +633,6 @@ def __init__(
self.index = []
self.pregenerate = False
if pregenerate:
#print("Pregenerating dataset...")
for idx in tqdm(range(self.num_samples)):
self.index.append(self.__getitem__(idx))
self.pregenerate = pregenerate
@@ -768,7 +767,6 @@ def __init__(
**kwargs,
):
super(WidebandModulationsDataset, self).__init__(**kwargs)
#print(f'seed -> {seed}')
self.random_generator = np.random.default_rng(seed)
self.update_rng = False
self.seed = seed
@@ -927,8 +925,6 @@ def __gen_metadata__(self, modulation_list: List) -> pd.DataFrame:
def __getitem__(self, item: int) -> Tuple[np.ndarray, Any]:
# Initialize empty list of signal sources & signal descriptors
if not self.update_rng:
# rng = np.random.default_rng(os.getpid())
#print(f'pid -> {os.getpid()}, updated dataset')
self.random_generator = np.random.default_rng(os.getpid())
self.update_rng = True
signal_sources: List[SyntheticBurstSourceDataset] = []
21 changes: 4 additions & 17 deletions torchsig/datasets/wideband_sig53.py
@@ -105,23 +105,13 @@ def __init__(
self.T = transform if transform else Identity()
self.TT = target_transform if target_transform else Identity()

cfg = (
"WidebandSig53"
+ ("Impaired" if impaired else "Clean")
+ ("Train" if train else "Val")
+ "Config"
)
cfg = ("WidebandSig53"+ ("Impaired" if impaired else "Clean") + ("Train" if train else "Val") + "Config")
cfg = getattr(conf, cfg)()

self.signal_desc_transform = ListTupleToDesc(
num_iq_samples=cfg.num_iq_samples, # type: ignore
class_list=self.modulation_list,
)
self.signal_desc_transform = ListTupleToDesc(num_iq_samples=cfg.num_iq_samples, class_list=self.modulation_list)

self.path = self.root / cfg.name # type: ignore
self.env = lmdb.Environment(
str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False
)
self.env = lmdb.Environment(str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False)
self.data_db = self.env.open_db(b"data")
self.label_db = self.env.open_db(b"label")
with self.env.begin(db=self.data_db) as data_txn:
@@ -138,10 +128,7 @@ def __getitem__(self, idx: int) -> tuple:
with self.env.begin(db=self.label_db) as label_txn:
label = pickle.loads(label_txn.get(encoded_idx))

signal = Signal(
data=create_signal_data(samples=iq_data),
metadata=self.signal_desc_transform(label),
)
signal = Signal(data=create_signal_data(samples=iq_data), metadata=self.signal_desc_transform(label))
signal = self.T(signal) # type: ignore
target = self.TT(signal["metadata"]) # type: ignore

6 changes: 1 addition & 5 deletions torchsig/transforms/target_transforms.py
@@ -55,7 +55,6 @@ def generate_mask(

begin_height = int((meta["lower_freq"] + 0.5) * height)
end_height = int((meta["upper_freq"] + 0.5) * height)
# print(f' start/stop in the transform -> {meta["start"]} {meta["stop"]}')
begin_width = int(meta["start"] * width)
end_width = int(meta["stop"] * width)

@@ -398,7 +397,6 @@ def __call__(self, metadata: List[SignalMetadata]) -> np.ndarray:
if not is_rf_modulated_metadata(meta):
continue

# print(f'start/stop {meta["start"]} {meta["stop"]}')
meta = meta_bound_frequency(meta)
masks = generate_mask(meta, masks, meta["class_index"], 1.0, self.height, self.width)

@@ -1390,9 +1388,7 @@ def __call__(
metadata: List[SignalMetadata] = []
# Loop through SignalMetadata's, converting values of interest to tuples
for curr_tuple in list_tuple:
tup: Tuple[Any, ...] = tuple(
[l.numpy() if isinstance(l, torch.Tensor) else l for l in curr_tuple]
)
tup: Tuple[Any, ...] = tuple([l.numpy() if isinstance(l, torch.Tensor) else l for l in curr_tuple])
name, start, stop, cf, bw, snr = tup
meta: SignalMetadata = create_modulated_rf_metadata(
sample_rate=0 if not self.sample_rate else self.sample_rate,
2 changes: 0 additions & 2 deletions torchsig/transforms/transforms.py
@@ -584,7 +584,6 @@ def check_freq_bounds(self, signal: Signal, new_rate: float) -> float:
else:
ret_rate = meta['upper_freq'] / .5
ret_list.append(ret_rate)
# print(f'adjusting resampling ratio new {ret_rate} old {new_rate}')
else:
ret_list.append(new_rate)

@@ -1681,7 +1680,6 @@ def check_freq_bounds(self, signal: Signal, freq_shift: float) -> float:
else:
freq_shift = .5 - meta['upper_freq']
ret_list.append(freq_shift)
# print(f'adjusting resampling ratio new {ret_rate} old {new_rate}')
else:
ret_list.append(freq_shift)

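Both check_freq_bounds hunks clamp against the Nyquist edges at ±0.5 in normalized frequency. A minimal sketch of the frequency-shift case, assuming upper_freq/lower_freq metadata as shown (the lower-edge branch is an assumed mirror of the visible upper-edge branch):

def bound_freq_shift(lower_freq: float, upper_freq: float, freq_shift: float) -> float:
    # Keep the shifted signal inside the normalized band [-0.5, 0.5].
    if upper_freq + freq_shift > 0.5:
        return 0.5 - upper_freq   # matches the `.5 - meta['upper_freq']` line above
    if lower_freq + freq_shift < -0.5:
        return -0.5 - lower_freq  # assumed symmetric branch, not shown in the hunk
    return freq_shift

print(bound_freq_shift(-0.1, 0.4, 0.3))  # 0.1: clamped so the signal tops out at +0.5
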
2 changes: 1 addition & 1 deletion torchsig/utils/visualize.py
@@ -8,7 +8,7 @@
import numpy as np
import pywt
import torch
import pdb


class Visualizer:
