diff --git a/torchxrayvision/datasets.py b/torchxrayvision/datasets.py index 7340c02..31b3863 100644 --- a/torchxrayvision/datasets.py +++ b/torchxrayvision/datasets.py @@ -1094,7 +1094,8 @@ def __getitem__(self, idx): sample["lab"] = self.labels[idx] imgid = self.csv['Path'].iloc[idx] - imgid = imgid.replace("CheXpert-v1.0-small/", "") + #clean up path in csv so the user can specify the path + imgid = imgid.replace("CheXpert-v1.0-small/", "").replace("CheXpert-v1.0/", "") img_path = os.path.join(self.imgpath, imgid) img = imread(img_path)