Skip to content

Commit

Permalink
Adding an option to compress (#216)
Browse files Browse the repository at this point in the history
* Adding an option to compress

* Adding compression to Wideband
  • Loading branch information
gvanhoy authored Sep 20, 2023
1 parent 2e7d2c3 commit ce9b03e
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 12 deletions.
17 changes: 9 additions & 8 deletions examples/00_sig53_dataset.ipynb

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion torchsig/datasets/sig53.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
train: bool = True,
impaired: bool = True,
eb_no: bool = False,
compressed: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
use_signal_data: bool = False,
Expand All @@ -71,6 +72,7 @@ def __init__(
self.train = train
self.impaired = impaired
self.eb_no = eb_no
self.compressed = compressed
self.use_signal_data = use_signal_data

self.T = transform if transform else Identity()
Expand Down Expand Up @@ -101,7 +103,11 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> Tuple[np.ndarray, Any]:
encoded_idx = pickle.dumps(idx)
with self.env.begin(db=self.data_db) as data_txn:
iq_data = pickle.loads(data_txn.get(encoded_idx))
iq_data: np.ndarray = pickle.loads(data_txn.get(encoded_idx))
if self.compressed:
iq_data = iq_data.astype(np.float64).view(np.complex128) / (
np.iinfo(np.int16).max - 1
)

with self.env.begin(db=self.label_db) as label_txn:
mod, snr = pickle.loads(label_txn.get(encoded_idx))
Expand Down
6 changes: 6 additions & 0 deletions torchsig/datasets/wideband_sig53.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
root: str,
train: bool = True,
impaired: bool = True,
compressed: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
use_signal_data: bool = True,
Expand All @@ -105,6 +106,7 @@ def __init__(

self.train = train
self.impaired = impaired
self.compressed = compressed

self.T = transform if transform else Identity()
self.TT = target_transform if target_transform else Identity()
Expand Down Expand Up @@ -139,6 +141,10 @@ def __getitem__(self, idx: int) -> tuple:
encoded_idx = pickle.dumps(idx)
with self.env.begin(db=self.data_db) as data_txn:
iq_data: np.ndarray = pickle.loads(data_txn.get(encoded_idx))
if self.compressed:
iq_data = iq_data.astype(np.float64).view(np.complex128) / (
np.iinfo(np.int16).max - 1
)

with self.env.begin(db=self.label_db) as label_txn:
label = pickle.loads(label_txn.get(encoded_idx))
Expand Down
36 changes: 33 additions & 3 deletions torchsig/utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch.utils.data import DataLoader
from typing import Callable, Optional
from functools import partial
from io import BytesIO
import numpy as np
import pickle
import random
Expand Down Expand Up @@ -80,9 +81,10 @@ class LMDBDatasetWriter(DatasetWriter):
path (str): directory in which to keep the database files
"""

def __init__(self, path: str, *args, **kwargs):
def __init__(self, path: str, compress: bool = False, *args, **kwargs):
    """Create an LMDB-backed dataset writer.

    Args:
        path: Directory in which the LMDB environment (data/label sub-databases)
            is created or opened.
        compress: When True, `write` quantizes IQ samples to int16 before
            pickling (see `_compress`); when False, samples are stored as-is.
        *args, **kwargs: Forwarded unchanged to the DatasetWriter base class.
    """
    super(LMDBDatasetWriter, self).__init__(*args, **kwargs)
    self.path = path
    self.compress = compress
    # map_size is the maximum size the database may grow to (~4 TB virtual
    # reservation); two named sub-databases hold samples and labels separately.
    self.env = lmdb.Environment(path, subdir=True, map_size=int(4e12), max_dbs=2)
    self.data_db = self.env.open_db(b"data")
    self.label_db = self.env.open_db(b"label")
Expand All @@ -93,6 +95,26 @@ def exists(self):
return True
return False

@staticmethod
def _compress(
data: np.ndarray, storage_type: np.dtype = np.dtype(np.int16)
) -> np.ndarray:
if storage_type == np.float64:
return data

floats = data.view(np.float64)
max_amp = np.max(np.abs(floats))
normalized = (np.iinfo(storage_type).max - 1) * floats / max_amp
digitized: np.ndarray = (
np.digitize(
normalized,
np.arange(np.iinfo(storage_type).min, np.iinfo(storage_type).max),
right=True,
)
- np.iinfo(storage_type).max
)
return digitized.astype(storage_type)

def write(self, batch):
data, labels = batch
with self.env.begin() as txn:
Expand All @@ -115,15 +137,23 @@ def write(self, batch):
)
for element_idx in range(len(data)):
if not isinstance(data[element_idx], np.ndarray):
compressed = self._compress(
data[element_idx].numpy(),
np.int16 if self.compress else np.float64,
)
txn.put(
pickle.dumps(last_idx + element_idx),
pickle.dumps(data[element_idx].numpy()),
pickle.dumps(compressed),
db=self.data_db,
)
continue
compressed = self._compress(
data[element_idx],
np.int16 if self.compress else np.float64,
)
txn.put(
pickle.dumps(last_idx + element_idx),
pickle.dumps(data[element_idx]),
pickle.dumps(compressed),
db=self.data_db,
)

Expand Down

0 comments on commit ce9b03e

Please sign in to comment.