diff --git a/Dockerfile b/Dockerfile index 51016d6..303fc66 100755 --- a/Dockerfile +++ b/Dockerfile @@ -9,8 +9,6 @@ RUN apt-get update && apt-get install -y \ rsync \ libgl1-mesa-glx - - ADD torchsig/ /build/torchsig ADD pyproject.toml /build/pyproject.toml diff --git a/scripts/generate_sig53.py b/scripts/generate_sig53.py index 8e536e1..704a220 100755 --- a/scripts/generate_sig53.py +++ b/scripts/generate_sig53.py @@ -7,12 +7,15 @@ import numpy as np -def generate(path: str, configs: List[conf.Sig53Config], num_workers: int): +def generate(path: str, configs: List[conf.Sig53Config], num_workers: int, num_samples_override: int): + num_samples = config.num_samples if num_samples == 0 else num_samples for config in configs: + num_samples = config.num_samples if num_samples_override <=0 else num_samples_override batch_size = int(np.min((config.num_samples // num_workers, 32))) + print(f'batch_size -> {batch_size} num_samples -> {num_samples}') ds = ModulationsDataset( level=config.level, - num_samples=config.num_samples, + num_samples=num_samples, num_iq_samples=config.num_iq_samples, use_class_idx=config.use_class_idx, include_snr=config.include_snr, @@ -38,9 +41,11 @@ def generate(path: str, configs: List[conf.Sig53Config], num_workers: int): @click.option("--root", default="sig53", help="Path to generate sig53 datasets") @click.option("--all", default=True, help="Generate all versions of sig53 dataset.") @click.option("--qa", default=False, help="Generate only QA versions of sig53 dataset.") +@click.option("--qa", default=False, help="Generate only QA versions of sig53 dataset.") +@click.option("--num-samples", default=-1, help="Override for number of dataset samples.") @click.option("--num-workers", "num_workers", default=os.cpu_count() // 2, help="Define number of workers for both DatasetLoader and DatasetCreator") @click.option("--impaired", default=False, help="Generate impaired dataset. Ignored if --all=True (default)") -def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int): +def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int, num_samples): if not os.path.isdir(root): os.mkdir(root) @@ -55,18 +60,19 @@ def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int): conf.Sig53ImpairedValQAConfig, ] if all: - generate(root, configs[:4], num_workers) + generate(root, configs[:4], num_workers, num_samples) return - if qa: - generate(root, configs[4:], num_workers) + elif qa: + generate(root, configs[4:], num_workers, num_samples) return - if impaired: - generate(root, configs[2:4], num_workers) + elif impaired: + generate(root, configs[2:4], num_workers, num_samples) return - generate(root, configs[:2], num_workers) + else: + generate(root, configs[:2], num_workers, num_samples) if __name__ == "__main__": diff --git a/torchsig/utils/writer.py b/torchsig/utils/writer.py index aaaf4f7..d940ee3 100755 --- a/torchsig/utils/writer.py +++ b/torchsig/utils/writer.py @@ -103,7 +103,7 @@ def write(self, batch): for label_idx, label in enumerate(labels): txn.put( pickle.dumps(last_idx + label_idx), - pickle.dumps(tuple(label)), + pickle.dumps(tuple(label.numpy())), db=self.label_db, ) if isinstance(labels, list): @@ -116,7 +116,7 @@ def write(self, batch): for element_idx in range(len(data)): txn.put( pickle.dumps(last_idx + element_idx), - pickle.dumps(data[element_idx]), + pickle.dumps(data[element_idx].numpy()), db=self.data_db, )