Commit
Showing 9 changed files with 610 additions and 432 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "579e425b-e5de-4fdc-9908-ed8706d57194", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"# Example 00 - The Official Sig53 Dataset\n", | ||
"This notebook walks through an example of how the official Sig53 dataset can be instantiated and analyzed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0d636a9e-55c1-47a1-bc20-9c472acecc3b", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"----\n", | ||
"### Import Libraries\n", | ||
"First, import all the necessary public libraries as well as a few classes from the `torchsig` toolkit." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "edd181f0-893f-4646-8d7a-2fe2ee2280f6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from torchsig.utils.visualize import IQVisualizer, SpectrogramVisualizer\n", | ||
"from torchsig.utils.dataset import SignalDataset\n", | ||
"from torchsig.datasets.sig53 import Sig53\n", | ||
"from torch.utils.data import DataLoader\n", | ||
"from matplotlib import pyplot as plt\n", | ||
"from typing import List\n", | ||
"from tqdm import tqdm\n", | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9d511e6b-7670-473b-a962-c08a9d341ec8", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"----\n", | ||
"### Instantiate Sig53 Dataset\n", | ||
"To instantiate the Sig53 dataset, several parameters are given to the imported `Sig53` class. These paramters are:\n", | ||
"- `root` ~ A string to specify the root directory of where to instantiate and/or read an existing Sig53 dataset\n", | ||
"- `train` ~ A boolean to specify if the Sig53 dataset should be the training (True) or validation (False) sets\n", | ||
"- `impaired` ~ A boolean to specify if the Sig53 dataset should be the clean version or the impaired version\n", | ||
"- `transform` ~ Optionally, pass in any data transforms here if the dataset will be used in an ML training pipeline\n", | ||
"- `target_transform` ~ Optionally, pass in any target transforms here if the dataset will be used in an ML training pipeline\n", | ||
"\n", | ||
"A combination of the `train` and the `impaired` booleans determines which of the four (4) distinct Sig53 datasets will be instantiated:\n", | ||
"- `train=True` & `impaired=False` = Clean training set of 1.06M examples\n", | ||
"- `train=True` & `impaired=True` = Impaired training set of 5.3M examples\n", | ||
"- `train=False` & `impaired=False` = Clean validation set of 106k examples\n", | ||
"- `train=False` & `impaired=True` = Impaired validation set of 106k examples\n", | ||
"\n", | ||
"The final option of the impaired validation set is the dataset to be used when reporting any results with the official Sig53 dataset.\n", | ||
"\n", | ||
"Additional optional parameters of potential interest are:\n", | ||
"- `regenerate` ~ A boolean specifying if the dataset should be regenerated even if an existing dataset is detected (Default: False)\n", | ||
"- `eb_no` ~ A boolean specifying if the SNR should be defined as Eb/No if True (making higher order modulations more powerful) or as Es/No if False (Defualt: False)\n", | ||
"- `use_signal_data` ~ A boolean specifying if the data and target information should be converted to `SignalData` objects as they are read in (Default: False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ee772ec3-c2b8-4cde-af9a-b1284df09342", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Specify Sig53 Options\n", | ||
"root = \"sig53/\"\n", | ||
"train = False\n", | ||
"impaired = False\n", | ||
"transform = None\n", | ||
"target_transform = None\n", | ||
"\n", | ||
"# Instantiate the Sig53 Dataset\n", | ||
"sig53 = Sig53(\n", | ||
" root=root,\n", | ||
" train=train,\n", | ||
" impaired=impaired,\n", | ||
" transform=transform,\n", | ||
" target_transform=target_transform,\n", | ||
")\n", | ||
"\n", | ||
"# Retrieve a sample and print out information\n", | ||
"idx = np.random.randint(len(sig53))\n", | ||
"data, (label, snr) = sig53[idx]\n", | ||
"print(\"Dataset length: {}\".format(len(sig53)))\n", | ||
"print(\"Data shape: {}\".format(data.shape))\n", | ||
"print(\"Label Index: {}\".format(label))\n", | ||
"print(\"Label Class: {}\".format(Sig53.convert_idx_to_name(label)))\n", | ||
"print(\"SNR: {}\".format(snr))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "80db34ff-80c2-49a0-96f3-d206cb307809", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"----\n", | ||
"### Plot Subset to Verify\n", | ||
"The `IQVisualizer` and the `SpectrogramVisualizer` can be passed a `Dataloader` and plot visualizations of the dataset. The `batch_size` of the `DataLoader` determines how many examples to plot for each iteration over the visualizer. Note that the dataset itself can be indexed and plotted sequentially using any familiar python plotting tools as an alternative plotting method to using the `torchsig` `Visualizer` as shown below." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b6b1d1fb-3663-459a-a6f7-35ca255c1365", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# For plotting, omit the SNR values\n", | ||
"class DataWrapper(SignalDataset):\n", | ||
" def __init__(self, dataset):\n", | ||
" self.dataset = dataset\n", | ||
" super().__init__(dataset)\n", | ||
"\n", | ||
" def __getitem__(self, idx):\n", | ||
" x, (y, _) = self.dataset[idx]\n", | ||
" return x, y\n", | ||
"\n", | ||
" def __len__(self) -> int:\n", | ||
" return len(self.dataset)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "84e05a27", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"plot_dataset = DataWrapper(sig53)\n", | ||
"\n", | ||
"data_loader = DataLoader(dataset=plot_dataset, batch_size=16, shuffle=True)\n", | ||
"\n", | ||
"\n", | ||
"# Transform the plotting titles from the class index to the name\n", | ||
"def target_idx_to_name(tensor: np.ndarray) -> List[str]:\n", | ||
" batch_size = tensor.shape[0]\n", | ||
" label = []\n", | ||
" for idx in range(batch_size):\n", | ||
" label.append(Sig53.convert_idx_to_name(int(tensor[idx])))\n", | ||
" return label\n", | ||
"\n", | ||
"\n", | ||
"visualizer = IQVisualizer(\n", | ||
" data_loader=data_loader,\n", | ||
" visualize_transform=None,\n", | ||
" visualize_target_transform=target_idx_to_name,\n", | ||
")\n", | ||
"\n", | ||
"for figure in iter(visualizer):\n", | ||
" figure.set_size_inches(14, 9)\n", | ||
" plt.show()\n", | ||
" break" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "16be4f03-fa82-4d29-9f08-fe547fd7053a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Repeat but plot the spectrograms for a new random sampling of the data\n", | ||
"visualizer = SpectrogramVisualizer(\n", | ||
" data_loader=data_loader,\n", | ||
" nfft=1024,\n", | ||
" visualize_transform=None,\n", | ||
" visualize_target_transform=target_idx_to_name,\n", | ||
")\n", | ||
"\n", | ||
"for figure in iter(visualizer):\n", | ||
" figure.set_size_inches(14, 9)\n", | ||
" plt.show()\n", | ||
" break" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0e8e793e-48f9-45a7-81a0-8276f61cc94a", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"----\n", | ||
"### Analyze Dataset\n", | ||
"The dataset can also be analyzed at the macro level for details such as the distribution of classes and SNR values. This exercise is performed below to show the nearly uniform distribution across each." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a988b188-fa07-4505-8f59-9bfab387243d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Loop through the dataset recording classes and SNRs\n", | ||
"class_counter_dict = {\n", | ||
" class_name: 0 for class_name in list(Sig53._idx_to_name_dict.values())\n", | ||
"}\n", | ||
"all_snrs = []\n", | ||
"\n", | ||
"for idx in tqdm(range(len(sig53))):\n", | ||
" data, (modulation, snr) = sig53[idx]\n", | ||
" class_counter_dict[Sig53.convert_idx_to_name(modulation)] += 1\n", | ||
" all_snrs.append(snr)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "020b5655-c6c4-4806-8b6a-dd027dbdb36f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Plot the distribution of classes\n", | ||
"class_names = list(class_counter_dict.keys())\n", | ||
"num_classes = list(class_counter_dict.values())\n", | ||
"\n", | ||
"plt.figure(figsize=(9, 9))\n", | ||
"plt.pie(num_classes, labels=class_names)\n", | ||
"plt.title(\"Class Distribution Pie Chart\")\n", | ||
"plt.show()\n", | ||
"\n", | ||
"plt.figure(figsize=(11, 4))\n", | ||
"plt.bar(class_names, num_classes)\n", | ||
"plt.xticks(rotation=90)\n", | ||
"plt.title(\"Class Distribution Bar Chart\")\n", | ||
"plt.xlabel(\"Modulation Class Name\")\n", | ||
"plt.ylabel(\"Counts\")\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c12ff742-cf0f-47f4-96ee-7ccdd147add2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Plot the distribution of SNR values\n", | ||
"plt.figure(figsize=(11, 4))\n", | ||
"plt.hist(x=all_snrs, bins=100)\n", | ||
"plt.title(\"SNR Distribution\")\n", | ||
"plt.xlabel(\"SNR Bins (dB)\")\n", | ||
"plt.ylabel(\"Counts\")\n", | ||
"plt.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
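
The class and SNR tally in the notebook's "Analyze Dataset" cells can be tried without downloading Sig53. The sketch below is a minimal stand-in, not the `torchsig` implementation: `ToyDataset`, `IDX_TO_NAME`, and `tally` are hypothetical names introduced here, and the sample length and SNR range are arbitrary placeholders. The only assumption carried over from the notebook is that indexing returns `(data, (label, snr))`.

```python
import random
from collections import Counter
from typing import List, Tuple

# Hypothetical class-name lookup; the real mapping comes from the Sig53 class.
IDX_TO_NAME = {0: "ook", 1: "bpsk", 2: "qpsk"}


class ToyDataset:
    """Stand-in that returns (data, (label, snr)), like sig53[idx] in the notebook."""

    def __init__(self, num_examples: int, num_samples: int = 8, seed: int = 0):
        rng = random.Random(seed)
        # Placeholder complex samples; the length is arbitrary.
        self.data = [
            [complex(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(num_samples)]
            for _ in range(num_examples)
        ]
        # Cycle through the labels so every class appears equally often.
        self.labels = [i % len(IDX_TO_NAME) for i in range(num_examples)]
        # Placeholder SNR values in dB; the range is an assumption.
        self.snrs = [rng.uniform(0.0, 30.0) for _ in range(num_examples)]

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        return self.data[idx], (self.labels[idx], self.snrs[idx])


def tally(dataset) -> Tuple[Counter, List[float]]:
    """Count examples per class name and collect every SNR value."""
    class_counts: Counter = Counter()
    all_snrs: List[float] = []
    for idx in range(len(dataset)):
        _, (label, snr) = dataset[idx]
        class_counts[IDX_TO_NAME[label]] += 1
        all_snrs.append(snr)
    return class_counts, all_snrs


toy = ToyDataset(num_examples=9)
class_counts, all_snrs = tally(toy)
print(dict(class_counts))
print(len(all_snrs))
```

Using a `Counter` avoids pre-populating a dictionary with zeros, at the cost of omitting classes that never occur; the notebook's explicit dictionary keeps empty classes visible in the bar chart.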