Skip to content

Commit

Permalink
fixing the .numpy() call in utils/writer.py
Browse files Browse the repository at this point in the history
	modified:   Dockerfile
	modified:   scripts/generate_sig53.py
	modified:   torchsig/utils/writer.py
  • Loading branch information
pvallance committed Jul 10, 2024
1 parent 78bbad0 commit 691302c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
2 changes: 0 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ RUN apt-get update && apt-get install -y \
rsync \
libgl1-mesa-glx



ADD torchsig/ /build/torchsig

ADD pyproject.toml /build/pyproject.toml
Expand Down
24 changes: 15 additions & 9 deletions scripts/generate_sig53.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
import numpy as np


def generate(path: str, configs: List[conf.Sig53Config], num_workers: int):
def generate(path: str, configs: List[conf.Sig53Config], num_workers: int, num_samples_override: int):
num_samples = config.num_samples if num_samples == 0 else num_samples
for config in configs:
num_samples = config.num_samples if num_samples_override <=0 else num_samples_override
batch_size = int(np.min((config.num_samples // num_workers, 32)))
print(f'batch_size -> {batch_size} num_samples -> {num_samples}')
ds = ModulationsDataset(
level=config.level,
num_samples=config.num_samples,
num_samples=num_samples,
num_iq_samples=config.num_iq_samples,
use_class_idx=config.use_class_idx,
include_snr=config.include_snr,
Expand All @@ -38,9 +41,11 @@ def generate(path: str, configs: List[conf.Sig53Config], num_workers: int):
@click.option("--root", default="sig53", help="Path to generate sig53 datasets")
@click.option("--all", default=True, help="Generate all versions of sig53 dataset.")
@click.option("--qa", default=False, help="Generate only QA versions of sig53 dataset.")
@click.option("--qa", default=False, help="Generate only QA versions of sig53 dataset.")
@click.option("--num-samples", default=-1, help="Override for number of dataset samples.")
@click.option("--num-workers", "num_workers", default=os.cpu_count() // 2, help="Define number of workers for both DatasetLoader and DatasetCreator")
@click.option("--impaired", default=False, help="Generate impaired dataset. Ignored if --all=True (default)")
def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int):
def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int, num_samples):
if not os.path.isdir(root):
os.mkdir(root)

Expand All @@ -55,18 +60,19 @@ def main(root: str, all: bool, qa: bool, impaired: bool, num_workers: int):
conf.Sig53ImpairedValQAConfig,
]
if all:
generate(root, configs[:4], num_workers)
generate(root, configs[:4], num_workers, num_samples)
return

if qa:
generate(root, configs[4:], num_workers)
elif qa:
generate(root, configs[4:], num_workers, num_samples)
return

if impaired:
generate(root, configs[2:4], num_workers)
elif impaired:
generate(root, configs[2:4], num_workers, num_samples)
return

generate(root, configs[:2], num_workers)
else:
generate(root, configs[:2], num_workers, num_samples)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions torchsig/utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def write(self, batch):
for label_idx, label in enumerate(labels):
txn.put(
pickle.dumps(last_idx + label_idx),
pickle.dumps(tuple(label)),
pickle.dumps(tuple(label.numpy())),
db=self.label_db,
)
if isinstance(labels, list):
Expand All @@ -116,7 +116,7 @@ def write(self, batch):
for element_idx in range(len(data)):
txn.put(
pickle.dumps(last_idx + element_idx),
pickle.dumps(data[element_idx]),
pickle.dumps(data[element_idx].numpy()),
db=self.data_db,
)

Expand Down

0 comments on commit 691302c

Please sign in to comment.