Skip to content

Commit

Permalink
Adjustments to generation
Browse files Browse the repository at this point in the history
  • Loading branch information
gvanhoy committed Aug 14, 2023
1 parent be92c61 commit 63c63e0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
6 changes: 3 additions & 3 deletions scripts/generate_wideband_sig53.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def generate(root: str, configs: List[conf.WidebandSig53Config]):

dataset_loader = DatasetLoader(
wideband_ds,
num_workers=8,
batch_size=8,
num_workers=os.cpu_count() // 2,
batch_size=os.cpu_count() // 2,
seed=12345678,
collate_fn=collate_fn,
)
Expand All @@ -45,7 +45,7 @@ def generate(root: str, configs: List[conf.WidebandSig53Config]):
"--all", default=True, help="Generate all versions of wideband_sig53 dataset."
)
@click.option(
"--qa", default=True, help="Generate only QA versions of wideband_sig53 dataset."
"--qa", default=False, help="Generate only QA versions of wideband_sig53 dataset."
)
@click.option(
"--impaired",
Expand Down
4 changes: 2 additions & 2 deletions torchsig/datasets/sig53.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(

self.path = self.root / cfg.name
self.env = lmdb.Environment(
str(self.path).encode(), map_size=int(1e12), max_dbs=2, lock=False
str(self.path).encode(), map_size=int(4e12), max_dbs=2, lock=False
)
self.data_db = self.env.open_db(b"data")
self.label_db = self.env.open_db(b"label")
Expand All @@ -101,7 +101,7 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> Tuple[np.ndarray, Any]:
encoded_idx = pickle.dumps(idx)
with self.env.begin(db=self.data_db) as data_txn:
iq_data = pickle.loads(data_txn.get(encoded_idx)).numpy()
iq_data = pickle.loads(data_txn.get(encoded_idx))

with self.env.begin(db=self.label_db) as label_txn:
mod, snr = pickle.loads(label_txn.get(encoded_idx))
Expand Down
14 changes: 11 additions & 3 deletions torchsig/utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ def __init__(
multiprocessing_context=torch.multiprocessing.get_context("fork"),
collate_fn=collate_fn,
)
self.length = int(len(dataset) / batch_size)

def __len__(self):
return self.length
return len(self.loader)

def __next__(self):
data, label = next(self.loader)
Expand Down Expand Up @@ -98,8 +97,10 @@ def exists(self):

def write(self, batch):
data, labels = batch
with self.env.begin(write=True) as txn:
with self.env.begin() as txn:
last_idx = txn.stat(db=self.data_db)["entries"]

with self.env.begin(write=True) as txn:
if isinstance(labels, tuple):
for label_idx, label in enumerate(labels):
txn.put(
Expand All @@ -115,6 +116,13 @@ def write(self, batch):
db=self.label_db,
)
for element_idx in range(len(data)):
if not isinstance(data[element_idx], np.ndarray):
txn.put(
pickle.dumps(last_idx + element_idx),
pickle.dumps(data[element_idx].numpy()),
db=self.data_db,
)
continue
txn.put(
pickle.dumps(last_idx + element_idx),
pickle.dumps(data[element_idx]),
Expand Down

0 comments on commit 63c63e0

Please sign in to comment.