From 63c63e037db1c4f6dec673be4dd0e7e1b36f2ad5 Mon Sep 17 00:00:00 2001 From: gvanhoy Date: Mon, 14 Aug 2023 13:45:36 -0400 Subject: [PATCH] Adjustments to generation --- scripts/generate_wideband_sig53.py | 6 +++--- torchsig/datasets/sig53.py | 4 ++-- torchsig/utils/writer.py | 14 +++++++++++--- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/scripts/generate_wideband_sig53.py b/scripts/generate_wideband_sig53.py index 487eb49..4171602 100644 --- a/scripts/generate_wideband_sig53.py +++ b/scripts/generate_wideband_sig53.py @@ -23,8 +23,8 @@ def generate(root: str, configs: List[conf.WidebandSig53Config]): dataset_loader = DatasetLoader( wideband_ds, - num_workers=8, - batch_size=8, + num_workers=os.cpu_count() // 2, + batch_size=os.cpu_count() // 2, seed=12345678, collate_fn=collate_fn, ) @@ -45,7 +45,7 @@ def generate(root: str, configs: List[conf.WidebandSig53Config]): "--all", default=True, help="Generate all versions of wideband_sig53 dataset." ) @click.option( - "--qa", default=True, help="Generate only QA versions of wideband_sig53 dataset." + "--qa", default=False, help="Generate only QA versions of wideband_sig53 dataset." ) @click.option( "--impaired", diff --git a/torchsig/datasets/sig53.py b/torchsig/datasets/sig53.py index df8c400..9acef13 100644 --- a/torchsig/datasets/sig53.py +++ b/torchsig/datasets/sig53.py @@ -88,7 +88,7 @@ def __init__( self.path = self.root / cfg.name self.env = lmdb.Environment( - str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False + str(self.path).encode(), map_size=int(4e12), max_dbs=2, lock=False ) self.data_db = self.env.open_db(b"data") self.label_db = self.env.open_db(b"label") @@ -101,7 +101,7 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> Tuple[np.ndarray, Any]: encoded_idx = pickle.dumps(idx) with self.env.begin(db=self.data_db) as data_txn: - iq_data = pickle.loads(data_txn.get(encoded_idx)).numpy() + iq_data = pickle.loads(data_txn.get(encoded_idx)) with self.env.begin(db=self.label_db) as label_txn: mod, snr = pickle.loads(label_txn.get(encoded_idx)) diff --git a/torchsig/utils/writer.py b/torchsig/utils/writer.py index aca130c..73e4dfe 100644 --- a/torchsig/utils/writer.py +++ b/torchsig/utils/writer.py @@ -53,10 +53,9 @@ def __init__( multiprocessing_context=torch.multiprocessing.get_context("fork"), collate_fn=collate_fn, ) - self.length = int(len(dataset) / batch_size) def __len__(self): - return self.length + return len(self.loader) def __next__(self): data, label = next(self.loader) @@ -98,8 +97,10 @@ def exists(self): def write(self, batch): data, labels = batch - with self.env.begin(write=True) as txn: + with self.env.begin() as txn: last_idx = txn.stat(db=self.data_db)["entries"] + + with self.env.begin(write=True) as txn: if isinstance(labels, tuple): for label_idx, label in enumerate(labels): txn.put( @@ -115,6 +116,13 @@ def write(self, batch): db=self.label_db, ) for element_idx in range(len(data)): + if not isinstance(data[element_idx], np.ndarray): + txn.put( + pickle.dumps(last_idx + element_idx), + pickle.dumps(data[element_idx].numpy()), + db=self.data_db, + ) + continue txn.put( pickle.dumps(last_idx + element_idx), pickle.dumps(data[element_idx]),