From b169826f82b8f642c2766f7f2fc35b2be69cd9a8 Mon Sep 17 00:00:00 2001 From: Denis Kleyko Date: Sun, 12 Feb 2023 17:18:12 +0100 Subject: [PATCH 1/4] Add model with GLVQ classifier --- examples/UCI_benchmark_intRVFL_GLVQ.py | 303 +++++++++++++++++++++++++ 1 file changed, 303 insertions(+) create mode 100644 examples/UCI_benchmark_intRVFL_GLVQ.py diff --git a/examples/UCI_benchmark_intRVFL_GLVQ.py b/examples/UCI_benchmark_intRVFL_GLVQ.py new file mode 100644 index 00000000..87a12ee3 --- /dev/null +++ b/examples/UCI_benchmark_intRVFL_GLVQ.py @@ -0,0 +1,303 @@ +import torch +import torch.nn as nn +import torch.utils.data as data +from torch import Tensor +from tqdm import tqdm + +# Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io +import torchmetrics +import torchhd +from torchhd.datasets import UCIClassificationBenchmark + +import copy +import warnings +import prototorch as pt +import pytorch_lightning as pl +from prototorch.models import GLVQ +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.optim.lr_scheduler import ExponentialLR +from torch.optim.lr_scheduler import LinearLR +from pytorch_lightning.callbacks import TQDMProgressBar +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=PossibleUserWarning) + + +# Function for performing min-max normalization of the input data samples +def create_min_max_normalize(min: Tensor, max: Tensor): + def normalize(input: Tensor) -> Tensor: + return torch.nan_to_num((input - min) / (max - min)) + + return normalize + + + +# Specify a model to be evaluated +class IntRVFLGLVQ(nn.Module): + """Class implementing integer random vector functional link network (intRVFL) model with Generalized Learning Vector Quantization (GLVQ) classifier as described in `Generalized Learning Vector Quantization for Classification in Randomized Neural Networks and Hyperdimensional Computing `_. + + Args: + dataset (torchhd.datasets.CollectionDataset): Specifies a dataset to be evaluted by the model. + num_feat (int): Number of features in the dataset. + max_epochs (int, optional): Number of epochs to train the classifier. + device (torch.device, optional): Specifies device to be used for Torch. + """ + + # These values of hyperparameters were found via the grid search for intRVFL model with GLVQ classifier as described in the paper. + INT_RVFL_GLVQ_HYPER = { + "abalone":(1450, 15, 15, 1), + "acute-inflammation":(50, 1, 5, 1), + "acute-nephritis":(50, 1, 2, 1), + "adult":(1150, 3, 15, 1), + "annealing":(1150, 7, 15, 1), + "arrhythmia":(1400, 7, 12, 1), + "audiology-std":(950, 3, 13, 1), + "balance-scale":(50, 7, 15, 1), + "balloons":(50, 1, 7, 1), + "bank":(200, 7, 14, 1), + "blood":(50, 7, 10, 1), + "breast-cancer":(50, 1, 13, 1), + "breast-cancer-wisc":(650, 3, 3, 1), + "breast-cancer-wisc-diag":(1500, 3, 15, 1), + "breast-cancer-wisc-prog":(1450, 3, 7, 1), + "breast-tissue":(1300, 1, 2, 1), + "car":(250, 3, 15, 1), + "cardiotocography-10clases":(1350, 3, 15, 1), + "cardiotocography-3clases":(900, 15, 15, 1), + "chess-krvk":(800, 1, 15, 1), + "chess-krvkp":(1350, 3, 14, 1), + "congressional-voting":(100, 15, 1, 1), + "conn-bench-sonar-mines-rocks":(1100, 3, 13, 1), + "conn-bench-vowel-deterding":(1350, 3, 15, 1), + "connect-4":(1100, 3, 1, 1), + "contrac":(50, 7, 15, 1), + "credit-approval":(200, 7, 13, 1), + "cylinder-bands":(1100, 7, 15, 1), + "dermatology":(900, 3, 15, 1), + "echocardiogram":(250, 15, 15, 1), + "ecoli":(350, 3, 9, 1), + "energy-y1":(650, 3, 15, 1), + "energy-y2":(1000, 7, 13, 1), + "fertility":(150, 7, 12, 1), + "flags":(900, 15, 10, 1), + "glass":(1400, 3, 13, 1), + "haberman-survival":(100, 3, 1, 1), + "hayes-roth":(50, 1, 8, 1), + "heart-cleveland":(50, 15, 4, 1), + "heart-hungarian":(50, 15, 3, 1), + "heart-switzerland":(50, 15, 4, 1), + "heart-va":(1350, 15, 6, 1), + "hepatitis":(1300, 1, 6, 1), + "hill-valley":(150, 1, 11, 1), + "horse-colic":(850, 1, 12, 1), + "ilpd-indian-liver":(1200, 7, 14, 1), + "image-segmentation":(650, 1, 15, 1), + "ionosphere":(1150, 1, 14, 1), + "iris":(50, 3, 1, 1), + "led-display":(50, 7, 1, 1), + "lenses":(50, 1, 5, 1), + "letter":(1500, 1, 15, 1), + "libras":(1250, 3, 13, 1), + "low-res-spect":(1400, 7, 14, 1), + "lung-cancer":(450, 1, 4, 1), + "lymphography":(1150, 1, 9, 1), + "magic":(800, 3, 15, 1), + "mammographic":(150, 7, 14, 1), + "miniboone":(650, 15, 15, 1), + "molec-biol-promoter":(1250, 1, 11, 1), + "molec-biol-splice":(1000, 15, 15, 1), + "monks-1":(50, 3, 8, 1), + "monks-2":(400, 1, 14, 1), + "monks-3":(50, 15, 12, 1), + "mushroom":(150, 3, 12, 1), + "musk-1":(1300, 7, 15, 1), + "musk-2":(1150, 7, 15, 1), + "nursery":(1000, 3, 15, 1), + "oocytes_merluccius_nucleus_4d":(1500, 7, 15, 1), + "oocytes_merluccius_states_2f":(1500, 7, 15, 1), + "oocytes_trisopterus_nucleus_2f":(1450, 3, 15, 1), + "oocytes_trisopterus_states_5b":(1450, 7, 15, 1), + "optical":(1100, 7, 15, 1), + "ozone":(50, 1, 1, 1), + "page-blocks":(800, 1, 15, 1), + "parkinsons":(1200, 1, 13, 1), + "pendigits":(1500, 1, 15, 1), + "pima":(50, 1, 13, 1), + "pittsburg-bridges-MATERIAL":(100, 1, 9, 1), + "pittsburg-bridges-REL-L":(1200, 1, 7, 1), + "pittsburg-bridges-SPAN":(450, 7, 7, 1), + "pittsburg-bridges-T-OR-D":(1000, 1, 8, 1), + "pittsburg-bridges-TYPE":(50, 7, 5, 1), + "planning":(50, 1, 1, 1), + "plant-margin":(1350, 7, 8, 1), + "plant-shape":(1450, 3, 14, 1), + "plant-texture":(1500, 7, 10, 1), + "post-operative":(50, 15, 1, 1), + "primary-tumor":(950, 3, 1, 1), + "ringnorm":(1500, 3, 14, 1), + "seeds":(550, 1, 12, 1), + "semeion":(1400, 15, 14, 1), + "soybean":(850, 3, 5, 1), + "spambase":(1350, 15, 15, 1), + "spect":(50, 1, 8, 1), + "spectf":(1100, 15, 15, 1), + "statlog-australian-credit":(200, 15, 1, 1), + "statlog-german-credit":(500, 15, 12, 1), + "statlog-heart":(50, 7, 14, 1), + "statlog-image":(950, 1, 15, 1), + "statlog-landsat":(1500, 3, 15, 1), + "statlog-shuttle":(100, 7, 15, 1), + "statlog-vehicle":(1450, 7, 14, 1), + "steel-plates":(1500, 3, 14, 1), + "synthetic-control":(1350, 3, 14, 1), + "teaching":(400, 3, 14, 1), + "thyroid":(300, 7, 15, 1), + "tic-tac-toe":(750, 1, 15, 1), + "titanic":(50, 1, 13, 1), + "trains":(100, 1, 7, 1), + "twonorm":(1100, 15, 14, 1), + "vertebral-column-2clases":(250, 3, 14, 1), + "vertebral-column-3clases":(200, 15, 14, 1), + "wall-following":(1200, 3, 15, 1), + "waveform":(1400, 7, 15, 1), + "waveform-noise":(1300, 15, 14, 1), + "wine":(850, 1, 15, 1), + "wine-quality-red":(1100, 1, 13, 1), + "wine-quality-white":(950, 3, 14, 1), + "yeast":(1350, 1, 5, 1), + "zoo":(400, 7, 1, 1), + } + + def __init__( + self, + dataset: torchhd.datasets.CollectionDataset, + num_feat: int, + max_epochs: int = 100, + device: torch.device = None, + ): + super(IntRVFLGLVQ, self).__init__() + self.device = device + self.num_feat = num_feat + self.max_epochs = max_epochs + # Fetch the hyperparameters for the corresponding dataset + hyper_param = self.INT_RVFL_GLVQ_HYPER[dataset.name] + # Dimensionality of vectors used when transforming input data + self.dimensions = hyper_param[0] + # Parameter of the clipping function used as the part of transforming input data + self.kappa = hyper_param[1] + self.transfer_beta = hyper_param[2] + self.prototypes_per_class = hyper_param[3] + # Number of classes in the dataset + self.num_classes = len(dataset.classes) + # Set up the encoding for the model as specified in "Density" + self.hypervector_encoding = torchhd.embeddings.Density( + self.num_feat, self.dimensions + ) + + # Specify encoding function for data samples + def encode(self, x): + return self.hypervector_encoding(x).clipping(self.kappa) + + # Specify how to make an inference step and issue a prediction + def forward(self, x): + # Make encodings for all data samples in the batch + encodings = self.encode(x) + # Classify with GLVQ classifier + predictions = self.classifier.predict(encodings) + return predictions + + # Train the classfier + def fit( + self, + train_ld: torch.utils.data.dataloader.DataLoader, + ): + # Hyperparameters for GLVQ + hparams = dict( + distribution={ + "num_classes": self.num_classes, + "per_class": self.prototypes_per_class, + }, + transfer_beta=self.transfer_beta, + lr=0.1, + ) + # Initialize the GLVQ classifier + self.classifier = GLVQ( + hparams, + optimizer=torch.optim.Adam, + prototypes_initializer=pt.initializers.SMCI(train_loader.dataset), + lr_scheduler=ExponentialLR, + lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), + ) + + # Setup trainer + trainer = pl.Trainer( + callbacks=[ + TQDMProgressBar(refresh_rate=3), + ], + max_epochs=self.max_epochs, + detect_anomaly=True, + ) + + # Training loop + trainer.fit(self.classifier, train_ld) + + +# Specify device to be used for Torch. +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print("Using {} device".format(device)) +# Specifies batch size to be used for the model. +batch_size = 200 +# Specifies how many random initializations of the model to evaluate for each dataset in the collection. +repeats = 5 + + +# Get an instance of the UCI benchmark +benchmark = UCIClassificationBenchmark("../data", download=True) +# Perform evaluation +for dataset in benchmark.datasets(): + print(dataset.name) + + #if dataset.name != 'Abalone': + # Number of features in the dataset. + num_feat = dataset.train[0][0].size(-1) + # Number of classes in the dataset. + num_classes = len(dataset.train.classes) + + # Get values for min-max normalization and add the transformation + min_val = torch.min(dataset.train.data, 0).values.to(device) + max_val = torch.max(dataset.train.data, 0).values.to(device) + transform = create_min_max_normalize(min_val, max_val) + dataset.test.transform = transform + + # Set up data loaders + test_loader = data.DataLoader(dataset.test, batch_size=batch_size) + + # Run for the requested number of simulations + for r in range(repeats): + # Creates a model to be evaluated. The model should specify both transformation of input data as weel as the algortihm for forming the classifier. + model = IntRVFLGLVQ( + getattr(torchhd.datasets, dataset.name), num_feat, device = device + ).to(device) + + # Replace raw data with hypervector encodings + train_ds = copy.deepcopy(dataset.train) + train_ds.data = torch.tensor(model.encode(transform(train_ds.data))) + + # Set up train loader + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True) + + + # Obtain the classifier for the model + model.fit(train_loader) + accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) + + with torch.no_grad(): + for samples, targets in tqdm(test_loader, desc="Testing"): + samples = samples.to(device) + # Make prediction + predictions = model(samples) + accuracy.update(predictions.cpu(), targets) + benchmark.report(dataset, accuracy.compute().item()) + +# Returns a dictionary with names of the datasets and their respective accuracy that is averaged over folds (if applicable) and repeats +benchmark_accuracy = benchmark.score() +print(benchmark_accuracy) From 852d48cf682e7e9f658b51405b496cb070b96697 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sun, 12 Feb 2023 16:25:04 +0000 Subject: [PATCH 2/4] [github-action] formatting fixes --- examples/UCI_benchmark_intRVFL_GLVQ.py | 265 +++++++++++++------------ 1 file changed, 133 insertions(+), 132 deletions(-) diff --git a/examples/UCI_benchmark_intRVFL_GLVQ.py b/examples/UCI_benchmark_intRVFL_GLVQ.py index 87a12ee3..b39f532e 100644 --- a/examples/UCI_benchmark_intRVFL_GLVQ.py +++ b/examples/UCI_benchmark_intRVFL_GLVQ.py @@ -18,6 +18,7 @@ from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import LinearLR from pytorch_lightning.callbacks import TQDMProgressBar + warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=PossibleUserWarning) @@ -30,7 +31,6 @@ def normalize(input: Tensor) -> Tensor: return normalize - # Specify a model to be evaluated class IntRVFLGLVQ(nn.Module): """Class implementing integer random vector functional link network (intRVFL) model with Generalized Learning Vector Quantization (GLVQ) classifier as described in `Generalized Learning Vector Quantization for Classification in Randomized Neural Networks and Hyperdimensional Computing `_. @@ -44,127 +44,127 @@ class IntRVFLGLVQ(nn.Module): # These values of hyperparameters were found via the grid search for intRVFL model with GLVQ classifier as described in the paper. INT_RVFL_GLVQ_HYPER = { - "abalone":(1450, 15, 15, 1), - "acute-inflammation":(50, 1, 5, 1), - "acute-nephritis":(50, 1, 2, 1), - "adult":(1150, 3, 15, 1), - "annealing":(1150, 7, 15, 1), - "arrhythmia":(1400, 7, 12, 1), - "audiology-std":(950, 3, 13, 1), - "balance-scale":(50, 7, 15, 1), - "balloons":(50, 1, 7, 1), - "bank":(200, 7, 14, 1), - "blood":(50, 7, 10, 1), - "breast-cancer":(50, 1, 13, 1), - "breast-cancer-wisc":(650, 3, 3, 1), - "breast-cancer-wisc-diag":(1500, 3, 15, 1), - "breast-cancer-wisc-prog":(1450, 3, 7, 1), - "breast-tissue":(1300, 1, 2, 1), - "car":(250, 3, 15, 1), - "cardiotocography-10clases":(1350, 3, 15, 1), - "cardiotocography-3clases":(900, 15, 15, 1), - "chess-krvk":(800, 1, 15, 1), - "chess-krvkp":(1350, 3, 14, 1), - "congressional-voting":(100, 15, 1, 1), - "conn-bench-sonar-mines-rocks":(1100, 3, 13, 1), - "conn-bench-vowel-deterding":(1350, 3, 15, 1), - "connect-4":(1100, 3, 1, 1), - "contrac":(50, 7, 15, 1), - "credit-approval":(200, 7, 13, 1), - "cylinder-bands":(1100, 7, 15, 1), - "dermatology":(900, 3, 15, 1), - "echocardiogram":(250, 15, 15, 1), - "ecoli":(350, 3, 9, 1), - "energy-y1":(650, 3, 15, 1), - "energy-y2":(1000, 7, 13, 1), - "fertility":(150, 7, 12, 1), - "flags":(900, 15, 10, 1), - "glass":(1400, 3, 13, 1), - "haberman-survival":(100, 3, 1, 1), - "hayes-roth":(50, 1, 8, 1), - "heart-cleveland":(50, 15, 4, 1), - "heart-hungarian":(50, 15, 3, 1), - "heart-switzerland":(50, 15, 4, 1), - "heart-va":(1350, 15, 6, 1), - "hepatitis":(1300, 1, 6, 1), - "hill-valley":(150, 1, 11, 1), - "horse-colic":(850, 1, 12, 1), - "ilpd-indian-liver":(1200, 7, 14, 1), - "image-segmentation":(650, 1, 15, 1), - "ionosphere":(1150, 1, 14, 1), - "iris":(50, 3, 1, 1), - "led-display":(50, 7, 1, 1), - "lenses":(50, 1, 5, 1), - "letter":(1500, 1, 15, 1), - "libras":(1250, 3, 13, 1), - "low-res-spect":(1400, 7, 14, 1), - "lung-cancer":(450, 1, 4, 1), - "lymphography":(1150, 1, 9, 1), - "magic":(800, 3, 15, 1), - "mammographic":(150, 7, 14, 1), - "miniboone":(650, 15, 15, 1), - "molec-biol-promoter":(1250, 1, 11, 1), - "molec-biol-splice":(1000, 15, 15, 1), - "monks-1":(50, 3, 8, 1), - "monks-2":(400, 1, 14, 1), - "monks-3":(50, 15, 12, 1), - "mushroom":(150, 3, 12, 1), - "musk-1":(1300, 7, 15, 1), - "musk-2":(1150, 7, 15, 1), - "nursery":(1000, 3, 15, 1), - "oocytes_merluccius_nucleus_4d":(1500, 7, 15, 1), - "oocytes_merluccius_states_2f":(1500, 7, 15, 1), - "oocytes_trisopterus_nucleus_2f":(1450, 3, 15, 1), - "oocytes_trisopterus_states_5b":(1450, 7, 15, 1), - "optical":(1100, 7, 15, 1), - "ozone":(50, 1, 1, 1), - "page-blocks":(800, 1, 15, 1), - "parkinsons":(1200, 1, 13, 1), - "pendigits":(1500, 1, 15, 1), - "pima":(50, 1, 13, 1), - "pittsburg-bridges-MATERIAL":(100, 1, 9, 1), - "pittsburg-bridges-REL-L":(1200, 1, 7, 1), - "pittsburg-bridges-SPAN":(450, 7, 7, 1), - "pittsburg-bridges-T-OR-D":(1000, 1, 8, 1), - "pittsburg-bridges-TYPE":(50, 7, 5, 1), - "planning":(50, 1, 1, 1), - "plant-margin":(1350, 7, 8, 1), - "plant-shape":(1450, 3, 14, 1), - "plant-texture":(1500, 7, 10, 1), - "post-operative":(50, 15, 1, 1), - "primary-tumor":(950, 3, 1, 1), - "ringnorm":(1500, 3, 14, 1), - "seeds":(550, 1, 12, 1), - "semeion":(1400, 15, 14, 1), - "soybean":(850, 3, 5, 1), - "spambase":(1350, 15, 15, 1), - "spect":(50, 1, 8, 1), - "spectf":(1100, 15, 15, 1), - "statlog-australian-credit":(200, 15, 1, 1), - "statlog-german-credit":(500, 15, 12, 1), - "statlog-heart":(50, 7, 14, 1), - "statlog-image":(950, 1, 15, 1), - "statlog-landsat":(1500, 3, 15, 1), - "statlog-shuttle":(100, 7, 15, 1), - "statlog-vehicle":(1450, 7, 14, 1), - "steel-plates":(1500, 3, 14, 1), - "synthetic-control":(1350, 3, 14, 1), - "teaching":(400, 3, 14, 1), - "thyroid":(300, 7, 15, 1), - "tic-tac-toe":(750, 1, 15, 1), - "titanic":(50, 1, 13, 1), - "trains":(100, 1, 7, 1), - "twonorm":(1100, 15, 14, 1), - "vertebral-column-2clases":(250, 3, 14, 1), - "vertebral-column-3clases":(200, 15, 14, 1), - "wall-following":(1200, 3, 15, 1), - "waveform":(1400, 7, 15, 1), - "waveform-noise":(1300, 15, 14, 1), - "wine":(850, 1, 15, 1), - "wine-quality-red":(1100, 1, 13, 1), - "wine-quality-white":(950, 3, 14, 1), - "yeast":(1350, 1, 5, 1), - "zoo":(400, 7, 1, 1), + "abalone": (1450, 15, 15, 1), + "acute-inflammation": (50, 1, 5, 1), + "acute-nephritis": (50, 1, 2, 1), + "adult": (1150, 3, 15, 1), + "annealing": (1150, 7, 15, 1), + "arrhythmia": (1400, 7, 12, 1), + "audiology-std": (950, 3, 13, 1), + "balance-scale": (50, 7, 15, 1), + "balloons": (50, 1, 7, 1), + "bank": (200, 7, 14, 1), + "blood": (50, 7, 10, 1), + "breast-cancer": (50, 1, 13, 1), + "breast-cancer-wisc": (650, 3, 3, 1), + "breast-cancer-wisc-diag": (1500, 3, 15, 1), + "breast-cancer-wisc-prog": (1450, 3, 7, 1), + "breast-tissue": (1300, 1, 2, 1), + "car": (250, 3, 15, 1), + "cardiotocography-10clases": (1350, 3, 15, 1), + "cardiotocography-3clases": (900, 15, 15, 1), + "chess-krvk": (800, 1, 15, 1), + "chess-krvkp": (1350, 3, 14, 1), + "congressional-voting": (100, 15, 1, 1), + "conn-bench-sonar-mines-rocks": (1100, 3, 13, 1), + "conn-bench-vowel-deterding": (1350, 3, 15, 1), + "connect-4": (1100, 3, 1, 1), + "contrac": (50, 7, 15, 1), + "credit-approval": (200, 7, 13, 1), + "cylinder-bands": (1100, 7, 15, 1), + "dermatology": (900, 3, 15, 1), + "echocardiogram": (250, 15, 15, 1), + "ecoli": (350, 3, 9, 1), + "energy-y1": (650, 3, 15, 1), + "energy-y2": (1000, 7, 13, 1), + "fertility": (150, 7, 12, 1), + "flags": (900, 15, 10, 1), + "glass": (1400, 3, 13, 1), + "haberman-survival": (100, 3, 1, 1), + "hayes-roth": (50, 1, 8, 1), + "heart-cleveland": (50, 15, 4, 1), + "heart-hungarian": (50, 15, 3, 1), + "heart-switzerland": (50, 15, 4, 1), + "heart-va": (1350, 15, 6, 1), + "hepatitis": (1300, 1, 6, 1), + "hill-valley": (150, 1, 11, 1), + "horse-colic": (850, 1, 12, 1), + "ilpd-indian-liver": (1200, 7, 14, 1), + "image-segmentation": (650, 1, 15, 1), + "ionosphere": (1150, 1, 14, 1), + "iris": (50, 3, 1, 1), + "led-display": (50, 7, 1, 1), + "lenses": (50, 1, 5, 1), + "letter": (1500, 1, 15, 1), + "libras": (1250, 3, 13, 1), + "low-res-spect": (1400, 7, 14, 1), + "lung-cancer": (450, 1, 4, 1), + "lymphography": (1150, 1, 9, 1), + "magic": (800, 3, 15, 1), + "mammographic": (150, 7, 14, 1), + "miniboone": (650, 15, 15, 1), + "molec-biol-promoter": (1250, 1, 11, 1), + "molec-biol-splice": (1000, 15, 15, 1), + "monks-1": (50, 3, 8, 1), + "monks-2": (400, 1, 14, 1), + "monks-3": (50, 15, 12, 1), + "mushroom": (150, 3, 12, 1), + "musk-1": (1300, 7, 15, 1), + "musk-2": (1150, 7, 15, 1), + "nursery": (1000, 3, 15, 1), + "oocytes_merluccius_nucleus_4d": (1500, 7, 15, 1), + "oocytes_merluccius_states_2f": (1500, 7, 15, 1), + "oocytes_trisopterus_nucleus_2f": (1450, 3, 15, 1), + "oocytes_trisopterus_states_5b": (1450, 7, 15, 1), + "optical": (1100, 7, 15, 1), + "ozone": (50, 1, 1, 1), + "page-blocks": (800, 1, 15, 1), + "parkinsons": (1200, 1, 13, 1), + "pendigits": (1500, 1, 15, 1), + "pima": (50, 1, 13, 1), + "pittsburg-bridges-MATERIAL": (100, 1, 9, 1), + "pittsburg-bridges-REL-L": (1200, 1, 7, 1), + "pittsburg-bridges-SPAN": (450, 7, 7, 1), + "pittsburg-bridges-T-OR-D": (1000, 1, 8, 1), + "pittsburg-bridges-TYPE": (50, 7, 5, 1), + "planning": (50, 1, 1, 1), + "plant-margin": (1350, 7, 8, 1), + "plant-shape": (1450, 3, 14, 1), + "plant-texture": (1500, 7, 10, 1), + "post-operative": (50, 15, 1, 1), + "primary-tumor": (950, 3, 1, 1), + "ringnorm": (1500, 3, 14, 1), + "seeds": (550, 1, 12, 1), + "semeion": (1400, 15, 14, 1), + "soybean": (850, 3, 5, 1), + "spambase": (1350, 15, 15, 1), + "spect": (50, 1, 8, 1), + "spectf": (1100, 15, 15, 1), + "statlog-australian-credit": (200, 15, 1, 1), + "statlog-german-credit": (500, 15, 12, 1), + "statlog-heart": (50, 7, 14, 1), + "statlog-image": (950, 1, 15, 1), + "statlog-landsat": (1500, 3, 15, 1), + "statlog-shuttle": (100, 7, 15, 1), + "statlog-vehicle": (1450, 7, 14, 1), + "steel-plates": (1500, 3, 14, 1), + "synthetic-control": (1350, 3, 14, 1), + "teaching": (400, 3, 14, 1), + "thyroid": (300, 7, 15, 1), + "tic-tac-toe": (750, 1, 15, 1), + "titanic": (50, 1, 13, 1), + "trains": (100, 1, 7, 1), + "twonorm": (1100, 15, 14, 1), + "vertebral-column-2clases": (250, 3, 14, 1), + "vertebral-column-3clases": (200, 15, 14, 1), + "wall-following": (1200, 3, 15, 1), + "waveform": (1400, 7, 15, 1), + "waveform-noise": (1300, 15, 14, 1), + "wine": (850, 1, 15, 1), + "wine-quality-red": (1100, 1, 13, 1), + "wine-quality-white": (950, 3, 14, 1), + "yeast": (1350, 1, 5, 1), + "zoo": (400, 7, 1, 1), } def __init__( @@ -185,7 +185,7 @@ def __init__( # Parameter of the clipping function used as the part of transforming input data self.kappa = hyper_param[1] self.transfer_beta = hyper_param[2] - self.prototypes_per_class = hyper_param[3] + self.prototypes_per_class = hyper_param[3] # Number of classes in the dataset self.num_classes = len(dataset.classes) # Set up the encoding for the model as specified in "Density" @@ -215,18 +215,18 @@ def fit( distribution={ "num_classes": self.num_classes, "per_class": self.prototypes_per_class, - }, + }, transfer_beta=self.transfer_beta, lr=0.1, ) # Initialize the GLVQ classifier self.classifier = GLVQ( hparams, - optimizer=torch.optim.Adam, + optimizer=torch.optim.Adam, prototypes_initializer=pt.initializers.SMCI(train_loader.dataset), lr_scheduler=ExponentialLR, lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), - ) + ) # Setup trainer trainer = pl.Trainer( @@ -236,7 +236,7 @@ def fit( max_epochs=self.max_epochs, detect_anomaly=True, ) - + # Training loop trainer.fit(self.classifier, train_ld) @@ -256,7 +256,7 @@ def fit( for dataset in benchmark.datasets(): print(dataset.name) - #if dataset.name != 'Abalone': + # if dataset.name != 'Abalone': # Number of features in the dataset. num_feat = dataset.train[0][0].size(-1) # Number of classes in the dataset. @@ -275,16 +275,17 @@ def fit( for r in range(repeats): # Creates a model to be evaluated. The model should specify both transformation of input data as weel as the algortihm for forming the classifier. model = IntRVFLGLVQ( - getattr(torchhd.datasets, dataset.name), num_feat, device = device + getattr(torchhd.datasets, dataset.name), num_feat, device=device ).to(device) # Replace raw data with hypervector encodings train_ds = copy.deepcopy(dataset.train) train_ds.data = torch.tensor(model.encode(transform(train_ds.data))) - - # Set up train loader - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True) + # Set up train loader + train_loader = torch.utils.data.DataLoader( + train_ds, batch_size=batch_size, shuffle=True + ) # Obtain the classifier for the model model.fit(train_loader) From 393dcacf1f9842d84cee86dca2c130fb667e0456 Mon Sep 17 00:00:00 2001 From: mikeheddes Date: Mon, 13 Feb 2023 10:44:56 -0800 Subject: [PATCH 3/4] Improve backward compatibility with torchmetrics --- examples/UCI_benchmark_intRVFL.py | 5 ++++- examples/UCI_benchmark_intRVFL_GLVQ.py | 17 +++++++++++------ examples/emg_hand_gestures.py | 4 +++- examples/graphhd.py | 4 +++- examples/language_recognition.py | 2 +- examples/mnist.py | 2 +- examples/mnist_hugging_face.py | 2 +- examples/mnist_nonlinear.py | 2 +- examples/voicehd.py | 2 +- 9 files changed, 26 insertions(+), 14 deletions(-) diff --git a/examples/UCI_benchmark_intRVFL.py b/examples/UCI_benchmark_intRVFL.py index 82ba3539..b84a9c13 100644 --- a/examples/UCI_benchmark_intRVFL.py +++ b/examples/UCI_benchmark_intRVFL.py @@ -308,7 +308,9 @@ def fit( # Obtain the classifier for the model model.fit(train_loader) - accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) + accuracy = torchmetrics.Accuracy( + task="multiclass", top_k=1, num_classes=num_classes + ) with torch.no_grad(): for samples, targets in tqdm(test_loader, desc="Testing"): @@ -317,6 +319,7 @@ def fit( predictions = model(samples) accuracy.update(predictions.cpu(), targets) + print(f"Accuracy: {(accuracy.compute().item() * 100):.3f}%") benchmark.report(dataset, accuracy.compute().item()) # Returns a dictionary with names of the datasets and their respective accuracy that is averaged over folds (if applicable) and repeats diff --git a/examples/UCI_benchmark_intRVFL_GLVQ.py b/examples/UCI_benchmark_intRVFL_GLVQ.py index b39f532e..fba53b7c 100644 --- a/examples/UCI_benchmark_intRVFL_GLVQ.py +++ b/examples/UCI_benchmark_intRVFL_GLVQ.py @@ -1,3 +1,5 @@ +import copy +import warnings import torch import torch.nn as nn import torch.utils.data as data @@ -9,14 +11,14 @@ import torchhd from torchhd.datasets import UCIClassificationBenchmark -import copy -import warnings +# Note: this example requires the prototorch library: https://github.com/si-cim/prototorch import prototorch as pt -import pytorch_lightning as pl +# Note: this example requires the prototorch-models library: https://github.com/si-cim/prototorch_models from prototorch.models import GLVQ +# Note: this example requires the pytorch-lightning library: https://www.pytorchlightning.ai +import pytorch_lightning as pl from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.optim.lr_scheduler import ExponentialLR -from torch.optim.lr_scheduler import LinearLR from pytorch_lightning.callbacks import TQDMProgressBar warnings.filterwarnings("ignore", category=UserWarning) @@ -256,7 +258,6 @@ def fit( for dataset in benchmark.datasets(): print(dataset.name) - # if dataset.name != 'Abalone': # Number of features in the dataset. num_feat = dataset.train[0][0].size(-1) # Number of classes in the dataset. @@ -289,7 +290,9 @@ def fit( # Obtain the classifier for the model model.fit(train_loader) - accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) + accuracy = torchmetrics.Accuracy( + task="multiclass", top_k=1, num_classes=num_classes + ) with torch.no_grad(): for samples, targets in tqdm(test_loader, desc="Testing"): @@ -297,6 +300,8 @@ def fit( # Make prediction predictions = model(samples) accuracy.update(predictions.cpu(), targets) + + print(f"Accuracy: {(accuracy.compute().item() * 100):.3f}%") benchmark.report(dataset, accuracy.compute().item()) # Returns a dictionary with names of the datasets and their respective accuracy that is averaged over folds (if applicable) and repeats diff --git a/examples/emg_hand_gestures.py b/examples/emg_hand_gestures.py index 74b9b1e0..85cb0792 100644 --- a/examples/emg_hand_gestures.py +++ b/examples/emg_hand_gestures.py @@ -74,7 +74,9 @@ def experiment(subjects=[0]): sample_hv = encode(samples) model.add(sample_hv, targets) - accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) + accuracy = torchmetrics.Accuracy( + task="multiclass", top_k=1, num_classes=num_classes + ) with torch.no_grad(): model.normalize() diff --git a/examples/graphhd.py b/examples/graphhd.py index fa86bd08..60a5d01a 100644 --- a/examples/graphhd.py +++ b/examples/graphhd.py @@ -115,7 +115,9 @@ def forward(self, x): samples_hv = encode(samples).unsqueeze(0) model.add(samples_hv, samples.y) -accuracy = torchmetrics.Accuracy("multiclass", num_classes=graphs.num_classes) +accuracy = torchmetrics.Accuracy( + task="multiclass", top_k=1, num_classes=graphs.num_classes +) with torch.no_grad(): model.normalize() diff --git a/examples/language_recognition.py b/examples/language_recognition.py index fd5da321..352575a0 100644 --- a/examples/language_recognition.py +++ b/examples/language_recognition.py @@ -81,7 +81,7 @@ def forward(self, x): samples_hv = encode(samples) model.add(samples_hv, labels) -accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) +accuracy = torchmetrics.Accuracy(task="multiclass", top_k=1, num_classes=num_classes) with torch.no_grad(): model.normalize() diff --git a/examples/mnist.py b/examples/mnist.py index 79bac5a9..2b3ddb3f 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -59,7 +59,7 @@ def forward(self, x): samples_hv = encode(samples) model.add(samples_hv, labels) -accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) +accuracy = torchmetrics.Accuracy(task="multiclass", top_k=1, num_classes=num_classes) with torch.no_grad(): model.normalize() diff --git a/examples/mnist_hugging_face.py b/examples/mnist_hugging_face.py index 9a6a6b8d..f0a9f459 100644 --- a/examples/mnist_hugging_face.py +++ b/examples/mnist_hugging_face.py @@ -72,7 +72,7 @@ def forward(self, x, dot=False): samples_hv = model.encode(samples) model.classify.add(samples_hv, labels) -accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) +accuracy = torchmetrics.Accuracy(task="multiclass", top_k=1, num_classes=num_classes) with torch.no_grad(): model.classify.normalize() diff --git a/examples/mnist_nonlinear.py b/examples/mnist_nonlinear.py index 493cc6c2..ce9a4c9c 100644 --- a/examples/mnist_nonlinear.py +++ b/examples/mnist_nonlinear.py @@ -57,7 +57,7 @@ def forward(self, x): samples_hv = encode(samples) model.add(samples_hv, labels) -accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) +accuracy = torchmetrics.Accuracy(task="multiclass", top_k=1, num_classes=num_classes) with torch.no_grad(): model.normalize() diff --git a/examples/voicehd.py b/examples/voicehd.py index 39a6e075..77ed94ee 100644 --- a/examples/voicehd.py +++ b/examples/voicehd.py @@ -52,7 +52,7 @@ def forward(self, x): samples_hv = encode(samples) model.add(samples_hv, labels) -accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes) +accuracy = torchmetrics.Accuracy(task="multiclass", top_k=1, num_classes=num_classes) with torch.no_grad(): model.normalize() From 9c24fbf5f624c254375ef9e50e6c0fff84de1d20 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 13 Feb 2023 18:45:26 +0000 Subject: [PATCH 4/4] [github-action] formatting fixes --- examples/UCI_benchmark_intRVFL_GLVQ.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/UCI_benchmark_intRVFL_GLVQ.py b/examples/UCI_benchmark_intRVFL_GLVQ.py index fba53b7c..c70301d5 100644 --- a/examples/UCI_benchmark_intRVFL_GLVQ.py +++ b/examples/UCI_benchmark_intRVFL_GLVQ.py @@ -13,8 +13,10 @@ # Note: this example requires the prototorch library: https://github.com/si-cim/prototorch import prototorch as pt + # Note: this example requires the prototorch-models library: https://github.com/si-cim/prototorch_models from prototorch.models import GLVQ + # Note: this example requires the pytorch-lightning library: https://www.pytorchlightning.ai import pytorch_lightning as pl from pytorch_lightning.utilities.warnings import PossibleUserWarning @@ -300,7 +302,7 @@ def fit( # Make prediction predictions = model(samples) accuracy.update(predictions.cpu(), targets) - + print(f"Accuracy: {(accuracy.compute().item() * 100):.3f}%") benchmark.report(dataset, accuracy.compute().item())