diff --git a/backbones.py b/backbones.py
index 4a5c43e..66620cd 100644
--- a/backbones.py
+++ b/backbones.py
@@ -5,7 +5,7 @@ from functools import partial
 
 class ConvBN2d(nn.Module):
-    def __init__(self, in_f, out_f, kernel_size = 3, stride = 1, padding = 1, groups = 1, outRelu = False, leaky = args.leaky):
+    def __init__(self, in_f, out_f, kernel_size = 1, stride = 1, padding = 0, groups = 1, outRelu = False, leaky = args.leaky):
         super(ConvBN2d, self).__init__()
         self.conv = nn.Conv2d(in_f, out_f, kernel_size = kernel_size, stride = stride, padding = padding, groups = groups, bias = False)
         self.bn = nn.BatchNorm2d(out_f)
@@ -20,6 +20,7 @@ def forward(self, x, lbda = None, perm = None):
         y = self.bn(self.conv(x))
         if lbda is not None:
             y = lbda * y + (1 - lbda) * y[perm]
+
         if self.outRelu:
             if not self.leaky:
                 return torch.relu(y)
@@ -31,13 +32,15 @@ def forward(self, x, lbda = None, perm = None):
 class BasicBlock(nn.Module):
     def __init__(self, in_f, out_f, stride=1, in_expansion = None):
         super(BasicBlock, self).__init__()
-        self.convbn1 = ConvBN2d(in_f, out_f, stride = stride, outRelu = True)
-        self.convbn2 = ConvBN2d(out_f, out_f)
+        self.convbn1 = ConvBN2d(in_f, 6*stride*in_f)
+        self.convbn2 = ConvBN2d(6*stride*in_f, 6*stride*in_f, stride = stride, outRelu = True, kernel_size = 7, padding = 3)
+        self.convbn3 = ConvBN2d(6*stride*in_f, out_f)
         self.shortcut = None if stride == 1 and in_f == out_f else ConvBN2d(in_f, out_f, kernel_size = 1, stride = stride, padding = 0)
 
     def forward(self, x, lbda = None, perm = None):
-        y = self.convbn1(x)
-        z = self.convbn2(y)
+        y1 = self.convbn1(x)
+        y2 = self.convbn2(y1)
+        z = self.convbn3(y2)
         if self.shortcut is not None:
             z += self.shortcut(x)
         else:
@@ -79,16 +82,16 @@ def __init__(self, block, blockList, featureMaps, large = False):
         super(ResNet, self).__init__()
         self.large = large
         if not large:
-            self.embed = ConvBN2d(3, featureMaps, outRelu = True)
+            self.embed = ConvBN2d(3, featureMaps, kernel_size=7, outRelu = True)
         else:
             self.embed = ConvBN2d(3, featureMaps, kernel_size=7, stride=2, padding=3, outRelu = True)
-            self.mp = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
+            self.mp = nn.MaxPool2d(kernel_size = 7, stride = 2, padding = 1)
         blocks = []
         lastMult = 1
         first = True
         for (nBlocks, stride, multiplier) in blockList:
             for i in range(nBlocks):
-                blocks.append(block(int(featureMaps * lastMult), int(featureMaps * multiplier), in_expansion = 1 if first else 4, stride = 1 if i > 0 else stride))
+                blocks.append(block(int(featureMaps * lastMult), int(featureMaps * multiplier), in_expansion = 1 if first else 4, stride = 1 if i != 0 else stride))
                 first = False
             lastMult = multiplier
         self.blocks = nn.ModuleList(blocks)
@@ -104,7 +107,6 @@ def forward(self, x, mixup = None, lbda = None, perm = None):
             x = lbda * x + (1 - lbda) * x[perm]
         if x.shape[1] == 1:
             x = x.repeat(1,3,1,1)
-
         if mixup_layer == 1:
             y = self.embed(x, lbda, perm)
         else:
@@ -112,13 +114,11 @@ def forward(self, x, mixup = None, lbda = None, perm = None):
         if self.large:
             y = self.mp(y)
-
         for i, block in enumerate(self.blocks):
             if mixup_layer == i + 2:
                 y = block(y, lbda, perm)
             else:
                 y = block(y)
-
         y = y.mean(dim = list(range(2, len(y.shape))))
         return y
@@ -213,6 +213,7 @@ def prepareBackbone():
         backbone = '_'.join(backbone.split('_')[:-1])
     return {
+        "resnet_ofa": lambda: (ResNet(BasicBlock, [(4, 1, 1), (4, 2, 2),(4,2,8),(4,2,16)], args.feature_maps, large = False), 16 * args.feature_maps),
         "resnet18": lambda: (ResNet(BasicBlock, [(2, 1, 1), (2, 2, 2), (2, 2, 4), (2, 2, 8)], args.feature_maps, large = large), 8 * args.feature_maps),
         "resnet20": lambda: (ResNet(BasicBlock, [(3, 1, 1), (3, 2, 2), (3, 2, 4)], args.feature_maps, large = large), 4 * args.feature_maps),
         "resnet56": lambda: (ResNet(BasicBlock, [(9, 1, 1), (9, 2, 2), (9, 2, 4)], args.feature_maps, large = large), 4 * args.feature_maps),
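Note on backbones.py: the BasicBlock is rewritten from two 3x3 convolutions into a 1x1 expansion to 6*stride*in_f channels, a 7x7 convolution, and a 1x1 projection; the ConvBN2d defaults change to kernel_size = 1, padding = 0 so the two 1x1 calls need no extra arguments. Since groups still defaults to 1, the 7x7 convolution runs at the full expanded width, which is substantially more expensive than a depthwise variant. Below is a minimal standalone sketch to sanity-check the shapes; ConvBN2d is simplified here (no mixup path, no args.leaky), and the tensor sizes are arbitrary.

import torch
import torch.nn as nn

class ConvBN2d(nn.Module):
    # Simplified stand-in for the repo's ConvBN2d: conv + BN, optional ReLU.
    def __init__(self, in_f, out_f, kernel_size=1, stride=1, padding=0, outRelu=False):
        super().__init__()
        self.conv = nn.Conv2d(in_f, out_f, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_f)
        self.outRelu = outRelu

    def forward(self, x):
        y = self.bn(self.conv(x))
        return torch.relu(y) if self.outRelu else y

class BasicBlock(nn.Module):
    # Expand (1x1) -> 7x7 conv at full width -> project (1x1), as in the patch.
    def __init__(self, in_f, out_f, stride=1):
        super().__init__()
        hidden = 6 * stride * in_f
        self.convbn1 = ConvBN2d(in_f, hidden)
        self.convbn2 = ConvBN2d(hidden, hidden, stride=stride, outRelu=True, kernel_size=7, padding=3)
        self.convbn3 = ConvBN2d(hidden, out_f)
        self.shortcut = None if stride == 1 and in_f == out_f else ConvBN2d(in_f, out_f, kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        z = self.convbn3(self.convbn2(self.convbn1(x)))
        z = z + (self.shortcut(x) if self.shortcut is not None else x)
        return torch.relu(z)

x = torch.randn(2, 64, 32, 32)
print(BasicBlock(64, 128, stride=2)(x).shape)  # torch.Size([2, 128, 16, 16])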
diff --git a/dataloaders.py b/dataloaders.py
index eb2a6e8..3c6681d 100644
--- a/dataloaders.py
+++ b/dataloaders.py
@@ -30,9 +30,10 @@ def __init__(self, data, targets, transforms, target_transforms=lambda x:x, open
         self.transforms = transforms
         self.target_transforms = target_transforms
         self.opener = opener
+
     def __getitem__(self, idx):
         if isinstance(self.data[idx], str):
-            elt = self.opener(args.dataset_path + self.data[idx])
+            elt = self.opener("/users2/local/datasets/" + self.data[idx])
         else:
             elt = self.data[idx]
         return self.transforms(elt), self.target_transforms(self.targets[idx])
@@ -206,10 +207,10 @@ def metadataset(datasetName, name):
     """
     Generic function to load a dataset from the Meta-Dataset v1.0
     """
-    f = open(args.dataset_path + "datasets.json")
+    f = open("/users2/local/datasets/" + "datasets.json")
     all_datasets = json.loads(f.read())
     f.close()
-    dataset = all_datasets[name+"_" + datasetName]
+    dataset = all_datasets[name + "_" + datasetName]
     if datasetName == "train":
         image_size = args.training_image_size if args.training_image_size>0 else 126
     else:
@@ -220,7 +221,7 @@ def metadataset(datasetName, name):
     else:
         default_test_transforms = ['metadatasettotensor', 'randomresizedcrop', 'biresize', 'metadatasetnorm']
     trans = get_transforms(image_size, datasetName, default_train_transforms, default_test_transforms)
-    return {"dataloader": dataLoader(DataHolder(dataset["data"], dataset["targets"], trans), shuffle = datasetName == "train", episodic=args.episodic and datasetName == "train", datasetName=name+"_"+datasetName), "name":dataset["name"], "num_classes":dataset["num_classes"], "name_classes": dataset["name_classes"]}
+    return {"dataloader": dataLoader(DataHolder(dataset["data"], dataset["targets"], trans), shuffle = True, episodic=args.episodic and datasetName == "train", datasetName=name+"_"+datasetName), "name":dataset["name"], "num_classes":dataset["num_classes"], "name_classes": dataset["name_classes"]}
 
 def metadataset_imagenet_v2():
     f = open(args.dataset_path + "datasets.json")
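Note on dataloaders.py: this hunk hardcodes the machine-specific root "/users2/local/datasets/" in place of args.dataset_path, and flips the loader to shuffle = True unconditionally, so validation and test loaders are now shuffled in addition to the training one. A small sketch of keeping the root configurable instead of hardcoding it; DATASETS_ROOT is a hypothetical environment variable, and args is the module's existing argument namespace.

import os

def dataset_root():
    # Fall back to the existing CLI argument when the env var is unset.
    root = os.environ.get("DATASETS_ROOT", args.dataset_path)
    return root if root.endswith("/") else root + "/"

# Usage in the two hunks above:
#     elt = self.opener(dataset_root() + self.data[idx])
#     f = open(dataset_root() + "datasets.json")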
diff --git a/few_shot_evaluation.py b/few_shot_evaluation.py
index f96dd34..3633cd0 100644
--- a/few_shot_evaluation.py
+++ b/few_shot_evaluation.py
@@ -188,7 +188,7 @@ def __init__(self, **kwargs):
         self.graph_map = {node['wn_id']:node['children_ids'] for node in self.graph}
         self.node_candidates = [node for node in self.graph_map.keys() if 5<=len(self.get_spanning_leaves(node))<=392]
-        self.classIdx = self.dataset["classIdx"]
+        self.classIdx = {'n02085620': 0, 'n02085782': 1, 'n02085936': 2, 'n02086079': 3, 'n02086240': 4, 'n02086646': 5, 'n02086910': 6, 'n02087046': 7, 'n02087394': 8, 'n02088094': 9, 'n02088238': 10, 'n02088364': 11, 'n02088466': 12, 'n02088632': 13, 'n02089078': 14, 'n02089867': 15, 'n02089973': 16, 'n02090379': 17, 'n02090622': 18, 'n02090721': 19, 'n02091032': 20, 'n02091134': 21, 'n02091244': 22, 'n02091467': 23, 'n02091635': 24, 'n02091831': 25, 'n02092002': 26, 'n02092339': 27, 'n02093256': 28, 'n02093428': 29, 'n02093647': 30, 'n02093754': 31, 'n02093859': 32, 'n02093991': 33, 'n02094114': 34, 'n02094258': 35, 'n02094433': 36, 'n02095314': 37, 'n02095570': 38, 'n02095889': 39, 'n02096051': 40, 'n02096177': 41, 'n02096294': 42, 'n02096437': 43, 'n02096585': 44, 'n02097047': 45, 'n02097130': 46, 'n02097209': 47, 'n02097298': 48, 'n02097474': 49, 'n02097658': 50, 'n02098105': 51, 'n02098286': 52, 'n02098413': 53, 'n02099267': 54, 'n02099429': 55, 'n02099601': 56, 'n02099712': 57, 'n02099849': 58, 'n02100236': 59, 'n02100583': 60, 'n02100735': 61, 'n02100877': 62, 'n02101006': 63, 'n02101388': 64, 'n02101556': 65, 'n02102040': 66, 'n02102177': 67, 'n02102318': 68, 'n02102480': 69, 'n02102973': 70, 'n02104029': 71, 'n02104365': 72, 'n02105056': 73, 'n02105162': 74, 'n02105251': 75, 'n02105412': 76, 'n02105505': 77, 'n02105641': 78, 'n02105855': 79, 'n02106030': 80, 'n02106166': 81, 'n02106382': 82, 'n02106550': 83, 'n02106662': 84, 'n02107142': 85, 'n02107312': 86, 'n02107574': 87, 'n02107683': 88, 'n02107908': 89, 'n02108000': 90, 'n02108089': 91, 'n02108422': 92, 'n02108551': 93, 'n02108915': 94, 'n02109047': 95, 'n02109525': 96, 'n02109961': 97, 'n02110063': 98, 'n02110185': 99, 'n02110341': 100, 'n02110627': 101, 'n02110806': 102, 'n02110958': 103, 'n02111129': 104, 'n02111277': 105, 'n02111500': 106, 'n02111889': 107, 'n02112018': 108, 'n02112137': 109, 'n02112350': 110, 'n02112706': 111, 'n02113023': 112, 'n02113186': 113, 'n02113624': 114, 'n02113712': 115, 'n02113799': 116, 'n02113978': 117, 'n02114367': 118, 'n02114548': 119, 'n02114712': 120, 'n02114855': 121, 'n02115641': 122, 'n02115913': 123, 'n02116738': 124, 'n02117135': 125, 'n02119022': 126, 'n02119789': 127, 'n02120079': 128, 'n02120505': 129, 'n02123045': 130, 'n02123159': 131, 'n02123394': 132, 'n02123597': 133, 'n02124075': 134, 'n02125311': 135, 'n02127052': 136, 'n02128385': 137, 'n02128757': 138, 'n02128925': 139, 'n02129165': 140, 'n02129604': 141, 'n02130308': 142, 'n02132136': 143, 'n02133161': 144, 'n02134084': 145, 'n02134418': 146, 'n02137549': 147, 'n02138441': 148, 'n02441942': 149, 'n02442845': 150, 'n02443114': 151, 'n02443484': 152, 'n02444819': 153, 'n02445715': 154, 'n02447366': 155, 'n02509815': 156, 'n02510455': 157}
 
     def get_spanning_leaves(self, node):
         """
         Given a graph and a node return the list of all leaves spanning from the node
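Note on few_shot_evaluation.py: self.classIdx becomes a hardcoded 158-entry WordNet-id-to-index dictionary instead of reading self.dataset["classIdx"]. If the intent is to decouple the evaluator from the dataset object, the mapping could be rebuilt from the metadata file; a sketch under the assumption (not verified against the repo) that the datasets.json entry lists its "name_classes" as WordNet ids in index order, with a hypothetical dataset key.

import json

def build_class_idx(json_path, dataset_key):
    # Map each WordNet id to its position, mirroring the hardcoded dict.
    with open(json_path) as f:
        dataset = json.load(f)[dataset_key]
    return {wn_id: i for i, wn_id in enumerate(dataset["name_classes"])}

# e.g. self.classIdx = build_class_idx(args.dataset_path + "datasets.json",
#                                      "metadataset_imagenet_test")  # key is a guess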
diff --git a/main.py b/main.py
index b49c4e5..cb02273 100644
--- a/main.py
+++ b/main.py
@@ -206,12 +206,19 @@ def generateFeatures(backbone, datasets, sample_aug=args.sample_aug):
     for augs in range(n_aug):
         features = [{"name_class": name_class, "features": []} for name_class in dataset["name_classes"]]
         for batchIdx, (data, target) in enumerate(dataset["dataloader"]):
-            if isinstance(data, dict):
-                data = data["supervised"]
-            data, target = to(data, args.device), target.to(args.device)
-            feats = backbone(data).to("cpu")
-            for i in range(feats.shape[0]):
-                features[target[i]]["features"].append(feats[i])
+            reda = True
+            if reda :
+                if isinstance(data, dict):
+                    data = data["supervised"]
+                data, target = to(data, args.device), target.to(args.device)
+                #print(data)
+                #print(data.shape)
+                #print("feaaaaaaaaaaaaaaaaats")
+                feats = backbone(data).to("cpu")
+                #print(feats.shape)
+                #print(feats)
+                for i in range(feats.shape[0]):
+                    features[target[i]]["features"].append(feats[i])
         for c in range(len(allFeatures)):
             if augs == 0:
                 allFeatures[c]["features"] = torch.stack(features[c]["features"])/n_aug
@@ -255,7 +262,7 @@ def get_optimizer(parameters, name, lr, weight_decay):
 import backbones
 backbone, outputDim = backbones.prepareBackbone()
 if args.load_backbone != "":
-    backbone.load_state_dict(torch.load(args.load_backbone))
+    backbone.load_state_dict(torch.load(args.load_backbone, map_location=args.device)["state_dict"])
 backbone = backbone.to(args.device)
 if not args.silent:
     numParamsBackbone = torch.tensor([m.numel() for m in backbone.parameters()]).sum().item()
@@ -269,7 +276,18 @@ def get_optimizer(parameters, name, lr, weight_decay):
     nSteps = math.ceil(args.dataset_size / args.batch_size)
 except:
     nSteps = 0
-
+
+print(backbone)
+#print(backbone.state_dict())
+#embed = nn.Sequential(backbone.embed)
+#block_net = nn.Sequential(*list(backbone.blocks))
+#backbone = nn.Sequential(embed, block_net)
+
+featuresValidation = generateFeatures(backbone, validationSet)
+biyam = testFewShot(featuresValidation, validationSet)
+print(biyam)
+print(STOP)
+
 criterion = {}
 teacher = {}
 all_steps = [item for sublist in eval(args.steps) for item in sublist]
@@ -329,6 +347,7 @@ def get_optimizer(parameters, name, lr, weight_decay):
     print()
     print(" ep. lr ".format(), end='')
     for dataset in trainSet:
+
         print(Back.CYAN + " {:>19s} ".format(dataset["name"]) + Style.RESET_ALL, end='')
     if epoch >= args.skip_epochs:
         for dataset in validationSet:
@@ -358,7 +377,6 @@ def get_optimizer(parameters, name, lr, weight_decay):
             else:
                 raise ValueError(f"Unknown scheduler {args.scheduler}")
             lr = lr * args.gamma
-
         continueTest = False
         meanVector = None
         trainStats = None
@@ -402,9 +420,11 @@ def get_optimizer(parameters, name, lr, weight_decay):
             if continueTest:
                 testStats = tempTestStats
         ender = Style.RESET_ALL
-        if continueTest and args.save_backbone != "" and epoch >= args.skip_epochs:
-            torch.save(backbone.to("cpu").state_dict(), args.save_backbone)
-            backbone.to(args.device)
+        if epoch != 0 and epoch % 10 == 0:
+            torch.save({"state_dict" : backbone.state_dict()}, f"backbone_{epoch}.pth")
+            torch.save({"state_dict" : criterion["supervised"][0].state_dict()}, f"class_meta_{epoch}.pth")
+            #torch.save(backbone.to("cpu").state_dict(), args.save_backbone)
+            #backbone.to(args.device)
         if continueTest and args.save_features_prefix != "" and epoch >= args.skip_epochs:
             for i, dataset in enumerate(trainSet):
                 torch.save(featuresTrain[i], args.save_features_prefix + dataset["name"] + "_features.pt")
@@ -427,16 +447,19 @@ def get_optimizer(parameters, name, lr, weight_decay):
         allRunTrainStats = torch.cat([allRunTrainStats, trainStats.unsqueeze(0)])
     else:
         allRunTrainStats = trainStats.unsqueeze(0)
+
     if validationSet != []:
         if allRunValidationStats is not None:
             allRunValidationStats = torch.cat([allRunValidationStats, validationStats.unsqueeze(0)])
         else:
+            print(validationStats,"STATS")
             allRunValidationStats = validationStats.unsqueeze(0)
     if testSet != []:
        if allRunTestStats is not None:
            allRunTestStats = torch.cat([allRunTestStats, testStats.unsqueeze(0)])
        else:
            allRunTestStats = testStats.unsqueeze(0)
+
    print()
    print("Run " + str(nRun+1) + "/" + str(args.runs) + " finished")
@@ -453,3 +476,7 @@
     print()
 if args.wandb!='':
     run_wandb.finish()
+
+
+
+
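Note on main.py: loading now unwraps a {"state_dict": ...} checkpoint with map_location, which breaks older checkpoints saved as raw state dicts; saving now ignores args.save_backbone and writes backbone_{epoch}.pth and class_meta_{epoch}.pth every 10 epochs; and the inserted validation block halts the run with print(STOP), i.e. a deliberate NameError. Below is a sketch of a loader tolerant of both checkpoint formats, plus a cleaner way to stop; load_backbone_weights is a hypothetical helper, not part of the repo.

import sys
import torch

def load_backbone_weights(backbone, path, device):
    # Accept both the old format (a bare state dict) and the new
    # {"state_dict": ...} wrapper written by the periodic checkpointing.
    ckpt = torch.load(path, map_location=device)
    state = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
    backbone.load_state_dict(state)
    return backbone

# After the one-off validation pass, exit cleanly instead of print(STOP):
#     sys.exit(0)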