diff --git a/gluefactory/datasets/hpatches.py b/gluefactory/datasets/hpatches.py
index cf4c7993..8ec93eae 100644
--- a/gluefactory/datasets/hpatches.py
+++ b/gluefactory/datasets/hpatches.py
@@ -4,7 +4,7 @@
 
 import argparse
 import logging
-import tarfile
+import zipfile
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -54,7 +54,7 @@ class HPatches(BaseDataset, torch.utils.data.Dataset):
         "v_astronautis",
         "v_talent",
     )
-    url = "http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz"
+    url = "https://www.kaggle.com/api/v1/datasets/download/javidtheimmortal/hpatches-sequence-release"
 
     def _init(self, conf):
         assert conf.batch_size == 1
@@ -79,11 +79,14 @@ def _init(self, conf):
     def download(self):
         data_dir = self.root.parent
         data_dir.mkdir(exist_ok=True, parents=True)
-        tar_path = data_dir / self.url.rsplit("/", 1)[-1]
-        torch.hub.download_url_to_file(self.url, tar_path)
-        with tarfile.open(tar_path) as tar:
-            tar.extractall(data_dir)
-        tar_path.unlink()
+        # The last URL component has no file extension; append ".zip" so the
+        # temporary archive cannot clash with an extracted entry of that name.
+        zip_path = data_dir / (self.url.rsplit("/", 1)[-1] + ".zip")
+        torch.hub.download_url_to_file(self.url, zip_path)
+        with zipfile.ZipFile(zip_path) as zip_ref:
+            zip_ref.extractall(data_dir)
+        # Delete the archive once extracted, as the tar-based code did.
+        zip_path.unlink()
 
     def get_dataset(self, split):
         assert split in ["val", "test"]