From cce5329e45a881152718a82c4a1d84708808ac19 Mon Sep 17 00:00:00 2001 From: pvallance Date: Sat, 13 Jul 2024 16:19:56 +0000 Subject: [PATCH] removed unused print statements add collate_fn in generate_sig53.py to avoid torch Dataloader from converting the labels to tensors. modified: scripts/generate_sig53.py modified: scripts/generate_wideband_sig53.py modified: scripts/train_sig53.py modified: torchsig/datasets/modulations.py modified: torchsig/datasets/sig53.py modified: torchsig/datasets/synthetic.py modified: torchsig/datasets/wideband.py modified: torchsig/datasets/wideband_sig53.py modified: torchsig/transforms/target_transforms.py modified: torchsig/transforms/transforms.py modified: torchsig/utils/visualize.py modified: torchsig/utils/writer.py --- scripts/generate_sig53.py | 22 ++--- scripts/generate_wideband_sig53.py | 103 +++++++++++++++++------ scripts/train_sig53.py | 17 +--- torchsig/datasets/modulations.py | 6 +- torchsig/datasets/sig53.py | 4 +- torchsig/datasets/synthetic.py | 2 - torchsig/datasets/wideband.py | 4 - torchsig/datasets/wideband_sig53.py | 21 +---- torchsig/transforms/target_transforms.py | 6 +- torchsig/transforms/transforms.py | 2 - torchsig/utils/visualize.py | 2 +- torchsig/utils/writer.py | 6 +- 12 files changed, 101 insertions(+), 94 deletions(-) diff --git a/scripts/generate_sig53.py b/scripts/generate_sig53.py index 704a220..47fc2b3 100755 --- a/scripts/generate_sig53.py +++ b/scripts/generate_sig53.py @@ -6,13 +6,16 @@ import os import numpy as np +def collate_fn(batch): + return tuple(zip(*batch)) + def generate(path: str, configs: List[conf.Sig53Config], num_workers: int, num_samples_override: int): - num_samples = config.num_samples if num_samples == 0 else num_samples + num_samples = config.num_samples if num_samples_override == 0 else num_samples_override for config in configs: num_samples = config.num_samples if num_samples_override <=0 else num_samples_override batch_size = int(np.min((config.num_samples // 
num_workers, 32))) - print(f'batch_size -> {batch_size} num_samples -> {num_samples}') + print(f'batch_size -> {batch_size} num_samples -> {num_samples}, config -> {config}') ds = ModulationsDataset( level=config.level, num_samples=num_samples, @@ -21,19 +24,8 @@ def generate(path: str, configs: List[conf.Sig53Config], num_workers: int, num_s include_snr=config.include_snr, eb_no=config.eb_no, ) - loader = DatasetLoader( - ds, - seed=12345678, - num_workers=num_workers, - batch_size=batch_size, - ) - creator = DatasetCreator( - ds, - seed=12345678, - path="{}".format(os.path.join(path, config.name)), - loader=loader, - num_workers=num_workers, - ) + dataset_loader = DatasetLoader(ds, seed=12345678, collate_fn=collate_fn, num_workers=num_workers, batch_size=batch_size) + creator = DatasetCreator(ds, seed=12345678, path="{}".format(os.path.join(path, config.name)), loader=dataset_loader, num_workers=num_workers) creator.create() diff --git a/scripts/generate_wideband_sig53.py b/scripts/generate_wideband_sig53.py index d5d9fcb..5787445 100755 --- a/scripts/generate_wideband_sig53.py +++ b/scripts/generate_wideband_sig53.py @@ -2,41 +2,88 @@ from torchsig.utils.writer import DatasetCreator, DatasetLoader from torchsig.datasets.wideband import WidebandModulationsDataset from torchsig.datasets import conf +from torchsig.transforms.transforms import * from typing import List import click import os -import numpy as np +import numpy as np def collate_fn(batch): return tuple(zip(*batch)) -def generate(root: str, configs: List[conf.WidebandSig53Config], num_workers: int): +modulation_list = ["ook", + "bpsk", + "4pam", + "4ask", + "qpsk", + "8pam", + "8ask", + "8psk", + "16qam", + "16pam", + "16ask", + "16psk", + "32qam", + "32qam_cross", + "32pam", + "32ask", + "32psk", + "64qam", + "64pam", + "64ask", + "64psk", + "128qam_cross", + "256qam", + "512qam_cross", + "1024qam", + "2fsk", + "2gfsk", + "2msk", + "2gmsk", + "4fsk", + "4gfsk", + "4msk", + "4gmsk", + "8fsk", + "8gfsk", 
+ "8msk", + "8gmsk", + "16fsk", + "16gfsk", + "16msk", + "16gmsk", + "ofdm-64", + "ofdm-72", + "ofdm-128", + "ofdm-180", + "ofdm-256", + "ofdm-300", + "ofdm-512", + "ofdm-600", + "ofdm-900", + "ofdm-1024", + "ofdm-1200", + "ofdm-2048", + ] + +def generate(root: str, configs: List[conf.WidebandSig53Config], num_workers: int, num_samples_override: int): for config in configs: - batch_size = int(np.min((config.num_samples // num_workers, 32))) + num_samples = config.num_samples if num_samples_override <=0 else num_samples_override + batch_size = int(np.min((num_samples // num_workers, 32))) + print(f'batch_size -> {batch_size} num_samples -> {num_samples}, config -> {config}') wideband_ds = WidebandModulationsDataset( level=config.level, num_iq_samples=config.num_iq_samples, - num_samples=config.num_samples, + num_samples=num_samples, + modulation_list=modulation_list, target_transform=DescToListTuple(), seed=config.seed, ) - dataset_loader = DatasetLoader( - wideband_ds, - seed=12345678, - collate_fn=collate_fn, - num_workers=num_workers, - batch_size=batch_size, - ) - creator = DatasetCreator( - wideband_ds, - seed=12345678, - path=os.path.join(root, config.name), - loader=dataset_loader, - num_workers=num_workers, - ) + dataset_loader = DatasetLoader(wideband_ds, seed=12345678, collate_fn=collate_fn, num_workers=num_workers, batch_size=batch_size) + creator = DatasetCreator(wideband_ds, seed=12345678, num_workers=num_workers, path=os.path.join(root, config.name), loader=dataset_loader,) creator.create() @@ -44,9 +91,10 @@ def generate(root: str, configs: List[conf.WidebandSig53Config], num_workers: in @click.option("--root", default="wideband_sig53", help="Path to generate wideband_sig53 datasets") @click.option("--all", default=True, help="Generate all versions of wideband_sig53 dataset.") @click.option("--qa", default=False, help="Generate only QA versions of wideband_sig53 dataset.") +@click.option("--num-samples", default=-1, help="Override for number of dataset 
samples.") +@click.option("--impaired", default=False, help="Generate impaired dataset. Ignored if --all=True (default)",) @click.option("--num-workers", "num_workers", default=os.cpu_count() // 2, help="Define number of workers for both DatasetLoader and DatasetCreator") -@click.option("--impaired", default=False, help="Generate impaired dataset. Ignored if --all=True (default)") -def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int): +def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int, num_samples: int): if not os.path.isdir(root): os.mkdir(root) @@ -60,20 +108,21 @@ def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int): conf.WidebandSig53ImpairedTrainQAConfig, conf.WidebandSig53ImpairedValQAConfig, ] - + if all: - generate(root, configs[:4], num_workers) + generate(root, configs, num_workers, num_samples) return - if qa: - generate(root, configs[4:], num_workers) + elif qa: + generate(root, configs[4:], num_workers, num_samples) return - if impaired: - generate(root, configs[2:4], num_workers) + elif impaired: + generate(root, configs[2:], num_workers, num_samples) return - generate(root, configs[:2], num_workers) + else: + generate(root, configs[:2], num_workers, num_samples) if __name__ == "__main__": diff --git a/scripts/train_sig53.py b/scripts/train_sig53.py index 29f59a3..8888c7a 100755 --- a/scripts/train_sig53.py +++ b/scripts/train_sig53.py @@ -11,7 +11,6 @@ from sklearn.metrics import classification_report from torchsig.utils.cm_plotter import plot_confusion_matrix from torchsig.datasets.sig53 import Sig53 -from torchsig.datasets import conf from torch.utils.data import DataLoader from matplotlib import pyplot as plt from torch import optim @@ -80,11 +79,6 @@ def main(root: str, impaired: bool): ) target_transform = DescToClassIndex(class_list=class_list) - if impaired == True: - num_samples = conf.Sig53ImpairedTrainConfig.num_samples - else: - num_samples = 
conf.Sig53CleanTrainConfig.num_samples - sig53_train = Sig53( root, train=True, @@ -103,21 +97,18 @@ def main(root: str, impaired: bool): use_signal_data=True, ) - num_workers=os.cpu_count() // 2 - batch_size=int(np.min((num_samples // num_workers, 32))) - print(batch_size,'batch_size',num_workers,'num_workers') # Create dataloaders"data train_dataloader = DataLoader( dataset=sig53_train, - batch_size=batch_size, - num_workers=num_workers, + batch_size=os.cpu_count(), + num_workers=os.cpu_count() // 2, shuffle=True, drop_last=True, ) val_dataloader = DataLoader( dataset=sig53_val, - batch_size=batch_size, - num_workers=num_workers, + batch_size=os.cpu_count(), + num_workers=os.cpu_count() // 2, shuffle=False, drop_last=True, ) diff --git a/torchsig/datasets/modulations.py b/torchsig/datasets/modulations.py index b2e7995..3d338dd 100755 --- a/torchsig/datasets/modulations.py +++ b/torchsig/datasets/modulations.py @@ -260,8 +260,9 @@ def __init__( ) if num_digital > 0 and num_ofdm > 0: - super(ModulationsDataset, self).__init__([digital_dataset, ofdm_dataset], **kwargs -) + super(ModulationsDataset, self).__init__([digital_dataset, ofdm_dataset], **kwargs) + # Torch's ConcatDataset should create this. 
+ elif num_digital > 0: super(ModulationsDataset, self).__init__([digital_dataset], **kwargs) elif num_ofdm > 0: @@ -269,5 +270,6 @@ def __init__( else: raise ValueError("Input classes must contain at least 1 valid class") + def __getitem__(self, item): return super(ModulationsDataset, self).__getitem__(item) diff --git a/torchsig/datasets/sig53.py b/torchsig/datasets/sig53.py index afa1a82..420c19c 100755 --- a/torchsig/datasets/sig53.py +++ b/torchsig/datasets/sig53.py @@ -84,9 +84,7 @@ def __init__( cfg = getattr(conf, cfg)() # type: ignore self.path = self.root / cfg.name - self.env = lmdb.Environment( - str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False - ) + self.env = lmdb.Environment(str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False) self.data_db = self.env.open_db(b"data") self.label_db = self.env.open_db(b"label") with self.env.begin(db=self.data_db) as data_txn: diff --git a/torchsig/datasets/synthetic.py b/torchsig/datasets/synthetic.py index a3fe95e..e977554 100755 --- a/torchsig/datasets/synthetic.py +++ b/torchsig/datasets/synthetic.py @@ -211,7 +211,6 @@ def __init__( ) super(DigitalModulationDataset, self).__init__([const_dataset, fsk_dataset, gfsks_dataset]) - class SyntheticDataset(SignalDataset): def __init__(self, **kwargs) -> None: super(SyntheticDataset, self).__init__(**kwargs) @@ -921,7 +920,6 @@ def _generate_samples(self, item: Tuple) -> np.ndarray: # scale the frequency map by the oversampling rate such that the tones # are packed tighter around f=0 the larger the oversampling rate const_oversampled = const / oversampling_rate - # print(f'const -> {const}, const_oversampled -> {const_oversampled} samples_per_symbol_recalculated -> {samples_per_symbol_recalculated}') orig_state = np.random.get_state() if not self.random_data: diff --git a/torchsig/datasets/wideband.py b/torchsig/datasets/wideband.py index d96d9f6..7985366 100755 --- a/torchsig/datasets/wideband.py +++ b/torchsig/datasets/wideband.py @@ 
-633,7 +633,6 @@ def __init__( self.index = [] self.pregenerate = False if pregenerate: - #print("Pregenerating dataset...") for idx in tqdm(range(self.num_samples)): self.index.append(self.__getitem__(idx)) self.pregenerate = pregenerate @@ -768,7 +767,6 @@ def __init__( **kwargs, ): super(WidebandModulationsDataset, self).__init__(**kwargs) - #print(f'seed -> {seed}') self.random_generator = np.random.default_rng(seed) self.update_rng = False self.seed = seed @@ -927,8 +925,6 @@ def __gen_metadata__(self, modulation_list: List) -> pd.DataFrame: def __getitem__(self, item: int) -> Tuple[np.ndarray, Any]: # Initialize empty list of signal sources & signal descriptors if not self.update_rng: - # rng = np.random.default_rng(os.getpid()) - #print(f'pid -> {os.getpid()}, updated dataset') self.random_generator = np.random.default_rng(os.getpid()) self.update_rng = True signal_sources: List[SyntheticBurstSourceDataset] = [] diff --git a/torchsig/datasets/wideband_sig53.py b/torchsig/datasets/wideband_sig53.py index 6a1784c..4c2afc0 100755 --- a/torchsig/datasets/wideband_sig53.py +++ b/torchsig/datasets/wideband_sig53.py @@ -105,23 +105,13 @@ def __init__( self.T = transform if transform else Identity() self.TT = target_transform if target_transform else Identity() - cfg = ( - "WidebandSig53" - + ("Impaired" if impaired else "Clean") - + ("Train" if train else "Val") - + "Config" - ) + cfg = ("WidebandSig53"+ ("Impaired" if impaired else "Clean") + ("Train" if train else "Val") + "Config") cfg = getattr(conf, cfg)() - self.signal_desc_transform = ListTupleToDesc( - num_iq_samples=cfg.num_iq_samples, # type: ignore - class_list=self.modulation_list, - ) + self.signal_desc_transform = ListTupleToDesc(num_iq_samples=cfg.num_iq_samples, class_list=self.modulation_list) self.path = self.root / cfg.name # type: ignore - self.env = lmdb.Environment( - str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False - ) + self.env = lmdb.Environment(str(self.path).encode(), 
map_size=int(1e12), max_dbs=2, lock=False) self.data_db = self.env.open_db(b"data") self.label_db = self.env.open_db(b"label") with self.env.begin(db=self.data_db) as data_txn: @@ -138,10 +128,7 @@ def __getitem__(self, idx: int) -> tuple: with self.env.begin(db=self.label_db) as label_txn: label = pickle.loads(label_txn.get(encoded_idx)) - signal = Signal( - data=create_signal_data(samples=iq_data), - metadata=self.signal_desc_transform(label), - ) + signal = Signal(data=create_signal_data(samples=iq_data), metadata=self.signal_desc_transform(label)) signal = self.T(signal) # type: ignore target = self.TT(signal["metadata"]) # type: ignore diff --git a/torchsig/transforms/target_transforms.py b/torchsig/transforms/target_transforms.py index 6932101..350b5b4 100755 --- a/torchsig/transforms/target_transforms.py +++ b/torchsig/transforms/target_transforms.py @@ -55,7 +55,6 @@ def generate_mask( begin_height = int((meta["lower_freq"] + 0.5) * height) end_height = int((meta["upper_freq"] + 0.5) * height) - # print(f' start/stop in the transform -> {meta["start"]} {meta["stop"]}') begin_width = int(meta["start"] * width) end_width = int(meta["stop"] * width) @@ -398,7 +397,6 @@ def __call__(self, metadata: List[SignalMetadata]) -> np.ndarray: if not is_rf_modulated_metadata(meta): continue - # print(f'start/stop {meta["start"]} {meta["stop"]}') meta = meta_bound_frequency(meta) masks = generate_mask(meta, masks, meta["class_index"], 1.0, self.height, self.width) @@ -1390,9 +1388,7 @@ def __call__( metadata: List[SignalMetadata] = [] # Loop through SignalMetadata's, converting values of interest to tuples for curr_tuple in list_tuple: - tup: Tuple[Any, ...] = tuple( - [l.numpy() if isinstance(l, torch.Tensor) else l for l in curr_tuple] - ) + tup: Tuple[Any, ...] 
= tuple([l.numpy() if isinstance(l, torch.Tensor) else l for l in curr_tuple]) name, start, stop, cf, bw, snr = tup meta: SignalMetadata = create_modulated_rf_metadata( sample_rate=0 if not self.sample_rate else self.sample_rate, diff --git a/torchsig/transforms/transforms.py b/torchsig/transforms/transforms.py index 45e0d1e..381e2c2 100755 --- a/torchsig/transforms/transforms.py +++ b/torchsig/transforms/transforms.py @@ -584,7 +584,6 @@ def check_freq_bounds(self, signal: Signal, new_rate: float) -> float: else: ret_rate = meta['upper_freq'] / .5 ret_list.append(ret_rate) - # print(f'adjusting resampling ratio new {ret_rate} old {new_rate}') else: ret_list.append(new_rate) @@ -1681,7 +1680,6 @@ def check_freq_bounds(self, signal: Signal, freq_shift: float) -> float: else: freq_shift = .5 - meta['upper_freq'] ret_list.append(freq_shift) - # print(f'adjusting resampling ratio new {ret_rate} old {new_rate}') else: ret_list.append(freq_shift) diff --git a/torchsig/utils/visualize.py b/torchsig/utils/visualize.py index 032ad15..d1cd35b 100755 --- a/torchsig/utils/visualize.py +++ b/torchsig/utils/visualize.py @@ -8,7 +8,7 @@ import numpy as np import pywt import torch -import pdb +import pdb class Visualizer: diff --git a/torchsig/utils/writer.py b/torchsig/utils/writer.py index d940ee3..67bd54d 100755 --- a/torchsig/utils/writer.py +++ b/torchsig/utils/writer.py @@ -103,7 +103,7 @@ def write(self, batch): for label_idx, label in enumerate(labels): txn.put( pickle.dumps(last_idx + label_idx), - pickle.dumps(tuple(label.numpy())), + pickle.dumps(tuple(label)), db=self.label_db, ) if isinstance(labels, list): @@ -116,7 +116,7 @@ def write(self, batch): for element_idx in range(len(data)): txn.put( pickle.dumps(last_idx + element_idx), - pickle.dumps(data[element_idx].numpy()), + pickle.dumps(data[element_idx]), db=self.data_db, ) @@ -146,7 +146,7 @@ def __init__( ) -> None: self.loader = DatasetLoader(dataset=dataset, seed=seed, num_workers=num_workers, 
batch_size=batch_size) self.loader = self.loader if not loader else loader - self.writer = LMDBDatasetWriter(path, map_size=1e12) + self.writer = LMDBDatasetWriter(path, map_size=1e13) self.writer = self.writer if not writer else writer # type: ignore self.path = path