[DO NOT MERGE] implement teacher ofa #74

Open · wants to merge 4 commits into base: main
23 changes: 12 additions & 11 deletions backbones.py
@@ -5,7 +5,7 @@
from functools import partial

class ConvBN2d(nn.Module):
def __init__(self, in_f, out_f, kernel_size = 3, stride = 1, padding = 1, groups = 1, outRelu = False, leaky = args.leaky):
def __init__(self, in_f, out_f, kernel_size = 1, stride = 1, padding = 0, groups = 1, outRelu = False, leaky = args.leaky):
super(ConvBN2d, self).__init__()
self.conv = nn.Conv2d(in_f, out_f, kernel_size = kernel_size, stride = stride, padding = padding, groups = groups, bias = False)
self.bn = nn.BatchNorm2d(out_f)
@@ -20,6 +20,7 @@ def forward(self, x, lbda = None, perm = None):
y = self.bn(self.conv(x))
if lbda is not None:
y = lbda * y + (1 - lbda) * y[perm]

if self.outRelu:
if not self.leaky:
return torch.relu(y)
@@ -31,13 +32,15 @@ def forward(self, x, lbda = None, perm = None):
class BasicBlock(nn.Module):
def __init__(self, in_f, out_f, stride=1, in_expansion = None):
super(BasicBlock, self).__init__()
self.convbn1 = ConvBN2d(in_f, out_f, stride = stride, outRelu = True)
self.convbn2 = ConvBN2d(out_f, out_f)
self.convbn1 = ConvBN2d(in_f, 6*stride*in_f)
self.convbn2 = ConvBN2d(6*stride*in_f, 6*stride*in_f, stride = stride, outRelu = True, kernel_size = 7, padding = 3)
self.convbn3 = ConvBN2d(6*stride*in_f, out_f)
self.shortcut = None if stride == 1 and in_f == out_f else ConvBN2d(in_f, out_f, kernel_size = 1, stride = stride, padding = 0)

def forward(self, x, lbda = None, perm = None):
y = self.convbn1(x)
z = self.convbn2(y)
y1 = self.convbn1(x)
y2 = self.convbn2(y1)
z = self.convbn3(y2)
if self.shortcut is not None:
z += self.shortcut(x)
else:
@@ -79,16 +82,16 @@ def __init__(self, block, blockList, featureMaps, large = False):
super(ResNet, self).__init__()
self.large = large
if not large:
self.embed = ConvBN2d(3, featureMaps, outRelu = True)
self.embed = ConvBN2d(3, featureMaps, kernel_size=7, outRelu = True)
else:
self.embed = ConvBN2d(3, featureMaps, kernel_size=7, stride=2, padding=3, outRelu = True)
self.mp = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
self.mp = nn.MaxPool2d(kernel_size = 7, stride = 2, padding = 1)
blocks = []
lastMult = 1
first = True
for (nBlocks, stride, multiplier) in blockList:
for i in range(nBlocks):
blocks.append(block(int(featureMaps * lastMult), int(featureMaps * multiplier), in_expansion = 1 if first else 4, stride = 1 if i > 0 else stride))
blocks.append(block(int(featureMaps * lastMult), int(featureMaps * multiplier), in_expansion = 1 if first else 4, stride = 1 if i != 0 else stride))
first = False
lastMult = multiplier
self.blocks = nn.ModuleList(blocks)
@@ -104,21 +107,18 @@ def forward(self, x, mixup = None, lbda = None, perm = None):
x = lbda * x + (1 - lbda) * x[perm]
if x.shape[1] == 1:
x = x.repeat(1,3,1,1)

if mixup_layer == 1:
y = self.embed(x, lbda, perm)
else:
y = self.embed(x)

if self.large:
y = self.mp(y)

for i, block in enumerate(self.blocks):
if mixup_layer == i + 2:
y = block(y, lbda, perm)
else:
y = block(y)

y = y.mean(dim = list(range(2, len(y.shape))))
return y

@@ -213,6 +213,7 @@ def prepareBackbone():
backbone = '_'.join(backbone.split('_')[:-1])

return {
"resnet_ofa": lambda: (ResNet(BasicBlock, [(4, 1, 1), (4, 2, 2),(4,2,8),(4,2,16)], args.feature_maps, large = False), 16 * args.feature_maps),
"resnet18": lambda: (ResNet(BasicBlock, [(2, 1, 1), (2, 2, 2), (2, 2, 4), (2, 2, 8)], args.feature_maps, large = large), 8 * args.feature_maps),
"resnet20": lambda: (ResNet(BasicBlock, [(3, 1, 1), (3, 2, 2), (3, 2, 4)], args.feature_maps, large = large), 4 * args.feature_maps),
"resnet56": lambda: (ResNet(BasicBlock, [(9, 1, 1), (9, 2, 2), (9, 2, 4)], args.feature_maps, large = large), 4 * args.feature_maps),
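
For reference, the rewritten BasicBlock swaps the classic two-conv residual block for an inverted-bottleneck design: a 1x1 expansion to 6*stride*in_f channels, a 7x7 convolution that carries the stride and the only internal ReLU, and a 1x1 projection back to out_f, plus the usual shortcut. A minimal self-contained sketch of what the block now computes (the mixup lbda/perm plumbing, the groups/leaky options, and the in_expansion argument, unused in the visible diff, are omitted; the identity branch and final ReLU are assumed from the stock block):

import torch
import torch.nn as nn

class ConvBN2d(nn.Module):
    # Conv + BatchNorm with the PR's new defaults: 1x1 kernel, no padding.
    def __init__(self, in_f, out_f, kernel_size=1, stride=1, padding=0, outRelu=False):
        super().__init__()
        self.conv = nn.Conv2d(in_f, out_f, kernel_size=kernel_size, stride=stride,
                              padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_f)
        self.outRelu = outRelu

    def forward(self, x):
        y = self.bn(self.conv(x))
        return torch.relu(y) if self.outRelu else y

class BasicBlock(nn.Module):
    # 1x1 expand -> 7x7 strided conv (ReLU) -> 1x1 project, with shortcut.
    def __init__(self, in_f, out_f, stride=1):
        super().__init__()
        hidden = 6 * stride * in_f
        self.convbn1 = ConvBN2d(in_f, hidden)
        self.convbn2 = ConvBN2d(hidden, hidden, kernel_size=7, stride=stride,
                                padding=3, outRelu=True)
        self.convbn3 = ConvBN2d(hidden, out_f)
        self.shortcut = (None if stride == 1 and in_f == out_f
                         else ConvBN2d(in_f, out_f, kernel_size=1, stride=stride))

    def forward(self, x):
        z = self.convbn3(self.convbn2(self.convbn1(x)))
        z = z + (x if self.shortcut is None else self.shortcut(x))
        return torch.relu(z)  # assumed final activation, as in the stock block

x = torch.randn(2, 16, 32, 32)
print(BasicBlock(16, 32, stride=2)(x).shape)  # torch.Size([2, 32, 16, 16])

With the new "resnet_ofa" registry entry, four stages of four such blocks with multipliers (1, 2, 8, 16) give a final feature dimension of 16 * args.feature_maps, matching the registered output size.
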
9 changes: 5 additions & 4 deletions dataloaders.py
@@ -30,9 +30,10 @@ def __init__(self, data, targets, transforms, target_transforms=lambda x:x, open
self.transforms = transforms
self.target_transforms = target_transforms
self.opener = opener

def __getitem__(self, idx):
if isinstance(self.data[idx], str):
elt = self.opener(args.dataset_path + self.data[idx])
elt = self.opener("/users2/local/datasets/" + self.data[idx])
else:
elt = self.data[idx]
return self.transforms(elt), self.target_transforms(self.targets[idx])
@@ -206,10 +207,10 @@ def metadataset(datasetName, name):
"""
Generic function to load a dataset from the Meta-Dataset v1.0
"""
f = open(args.dataset_path + "datasets.json")
f = open("/users2/local/datasets/" + "datasets.json")
all_datasets = json.loads(f.read())
f.close()
dataset = all_datasets[name+"_" + datasetName]
dataset = all_datasets[name + "_" + datasetName]
if datasetName == "train":
image_size = args.training_image_size if args.training_image_size>0 else 126
else:
@@ -220,7 +221,7 @@ def metadataset(datasetName, name):
else:
default_test_transforms = ['metadatasettotensor', 'randomresizedcrop', 'biresize', 'metadatasetnorm']
trans = get_transforms(image_size, datasetName, default_train_transforms, default_test_transforms)
return {"dataloader": dataLoader(DataHolder(dataset["data"], dataset["targets"], trans), shuffle = datasetName == "train", episodic=args.episodic and datasetName == "train", datasetName=name+"_"+datasetName), "name":dataset["name"], "num_classes":dataset["num_classes"], "name_classes": dataset["name_classes"]}
return {"dataloader": dataLoader(DataHolder(dataset["data"], dataset["targets"], trans), shuffle = True, episodic=args.episodic and datasetName == "train", datasetName=name+"_"+datasetName), "name":dataset["name"], "num_classes":dataset["num_classes"], "name_classes": dataset["name_classes"]}

def metadataset_imagenet_v2():
f = open(args.dataset_path + "datasets.json")
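
The metadataset loader resolves each split through a single datasets.json index keyed by "<split>_<dataset>". A sketch of that lookup under the now hard-coded root (field names follow the return dict above; the "imagenet" key is only an illustration, and the root path is machine-specific):

import json
import os

DATASET_ROOT = "/users2/local/datasets/"  # hard-coded by this PR; was args.dataset_path

def load_dataset_entry(split, name, root=DATASET_ROOT):
    # datasets.json maps "<split>_<dataset>" keys to dicts carrying
    # "data", "targets", "name", "num_classes" and "name_classes".
    with open(os.path.join(root, "datasets.json")) as f:
        all_datasets = json.load(f)
    return all_datasets[split + "_" + name]

# entry = load_dataset_entry("train", "imagenet")
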
2 changes: 1 addition & 1 deletion few_shot_evaluation.py
@@ -188,7 +188,7 @@ def __init__(self, **kwargs):

self.graph_map = {node['wn_id']:node['children_ids'] for node in self.graph}
self.node_candidates = [node for node in self.graph_map.keys() if 5<=len(self.get_spanning_leaves(node))<=392]
self.classIdx = self.dataset["classIdx"]
self.classIdx = {'n02085620': 0, 'n02085782': 1, 'n02085936': 2, 'n02086079': 3, 'n02086240': 4, 'n02086646': 5, 'n02086910': 6, 'n02087046': 7, 'n02087394': 8, 'n02088094': 9, 'n02088238': 10, 'n02088364': 11, 'n02088466': 12, 'n02088632': 13, 'n02089078': 14, 'n02089867': 15, 'n02089973': 16, 'n02090379': 17, 'n02090622': 18, 'n02090721': 19, 'n02091032': 20, 'n02091134': 21, 'n02091244': 22, 'n02091467': 23, 'n02091635': 24, 'n02091831': 25, 'n02092002': 26, 'n02092339': 27, 'n02093256': 28, 'n02093428': 29, 'n02093647': 30, 'n02093754': 31, 'n02093859': 32, 'n02093991': 33, 'n02094114': 34, 'n02094258': 35, 'n02094433': 36, 'n02095314': 37, 'n02095570': 38, 'n02095889': 39, 'n02096051': 40, 'n02096177': 41, 'n02096294': 42, 'n02096437': 43, 'n02096585': 44, 'n02097047': 45, 'n02097130': 46, 'n02097209': 47, 'n02097298': 48, 'n02097474': 49, 'n02097658': 50, 'n02098105': 51, 'n02098286': 52, 'n02098413': 53, 'n02099267': 54, 'n02099429': 55, 'n02099601': 56, 'n02099712': 57, 'n02099849': 58, 'n02100236': 59, 'n02100583': 60, 'n02100735': 61, 'n02100877': 62, 'n02101006': 63, 'n02101388': 64, 'n02101556': 65, 'n02102040': 66, 'n02102177': 67, 'n02102318': 68, 'n02102480': 69, 'n02102973': 70, 'n02104029': 71, 'n02104365': 72, 'n02105056': 73, 'n02105162': 74, 'n02105251': 75, 'n02105412': 76, 'n02105505': 77, 'n02105641': 78, 'n02105855': 79, 'n02106030': 80, 'n02106166': 81, 'n02106382': 82, 'n02106550': 83, 'n02106662': 84, 'n02107142': 85, 'n02107312': 86, 'n02107574': 87, 'n02107683': 88, 'n02107908': 89, 'n02108000': 90, 'n02108089': 91, 'n02108422': 92, 'n02108551': 93, 'n02108915': 94, 'n02109047': 95, 'n02109525': 96, 'n02109961': 97, 'n02110063': 98, 'n02110185': 99, 'n02110341': 100, 'n02110627': 101, 'n02110806': 102, 'n02110958': 103, 'n02111129': 104, 'n02111277': 105, 'n02111500': 106, 'n02111889': 107, 'n02112018': 108, 'n02112137': 109, 'n02112350': 110, 'n02112706': 111, 'n02113023': 112, 'n02113186': 113, 'n02113624': 114, 'n02113712': 115, 'n02113799': 116, 'n02113978': 117, 'n02114367': 118, 'n02114548': 119, 'n02114712': 120, 'n02114855': 121, 'n02115641': 122, 'n02115913': 123, 'n02116738': 124, 'n02117135': 125, 'n02119022': 126, 'n02119789': 127, 'n02120079': 128, 'n02120505': 129, 'n02123045': 130, 'n02123159': 131, 'n02123394': 132, 'n02123597': 133, 'n02124075': 134, 'n02125311': 135, 'n02127052': 136, 'n02128385': 137, 'n02128757': 138, 'n02128925': 139, 'n02129165': 140, 'n02129604': 141, 'n02130308': 142, 'n02132136': 143, 'n02133161': 144, 'n02134084': 145, 'n02134418': 146, 'n02137549': 147, 'n02138441': 148, 'n02441942': 149, 'n02442845': 150, 'n02443114': 151, 'n02443484': 152, 'n02444819': 153, 'n02445715': 154, 'n02447366': 155, 'n02509815': 156, 'n02510455': 157}
def get_spanning_leaves(self, node):
"""
Given a graph and a node return the list of all leaves spanning from the node
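
The inlined classIdx dict above appears to enumerate its 158 WordNet ids in sorted order (n02085620 -> 0 through n02510455 -> 157). If that holds, the mapping can be rebuilt from the id list rather than hard-coded; a sketch, assuming the ids are available from the dataset metadata:

wn_ids = ["n02085620", "n02085782", "n02085936"]  # ...plus the remaining 155 ids
classIdx = {wn_id: i for i, wn_id in enumerate(sorted(wn_ids))}
assert classIdx["n02085620"] == 0 and classIdx["n02085936"] == 2
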
51 changes: 39 additions & 12 deletions main.py
@@ -206,12 +206,19 @@ def generateFeatures(backbone, datasets, sample_aug=args.sample_aug):
for augs in range(n_aug):
features = [{"name_class": name_class, "features": []} for name_class in dataset["name_classes"]]
for batchIdx, (data, target) in enumerate(dataset["dataloader"]):
if isinstance(data, dict):
data = data["supervised"]
data, target = to(data, args.device), target.to(args.device)
feats = backbone(data).to("cpu")
for i in range(feats.shape[0]):
features[target[i]]["features"].append(feats[i])
reda = True
if reda :
if isinstance(data, dict):
data = data["supervised"]
data, target = to(data, args.device), target.to(args.device)
#print(data)
#print(data.shape)
#print("feaaaaaaaaaaaaaaaaats")
feats = backbone(data).to("cpu")
#print(feats.shape)
#print(feats)
for i in range(feats.shape[0]):
features[target[i]]["features"].append(feats[i])
for c in range(len(allFeatures)):
if augs == 0:
allFeatures[c]["features"] = torch.stack(features[c]["features"])/n_aug
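
Around this hunk, generateFeatures folds the features from each of the n_aug augmented passes into a per-class running mean: the first pass initialises the buffer with stack(...)/n_aug, and later passes are presumably accumulated the same way in code outside the diff. A toy sketch of that averaging scheme:

import torch

n_aug, n_samples, dim = 4, 8, 16
mean_feats = torch.zeros(n_samples, dim)
for aug in range(n_aug):
    feats = torch.randn(n_samples, dim)  # stands in for backbone(data)
    mean_feats += feats / n_aug          # running mean over the n_aug passes
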
@@ -255,7 +262,7 @@ def get_optimizer(parameters, name, lr, weight_decay):
import backbones
backbone, outputDim = backbones.prepareBackbone()
if args.load_backbone != "":
backbone.load_state_dict(torch.load(args.load_backbone))
backbone.load_state_dict(torch.load(args.load_backbone, map_location=args.device)["state_dict"])
backbone = backbone.to(args.device)
if not args.silent:
numParamsBackbone = torch.tensor([m.numel() for m in backbone.parameters()]).sum().item()
@@ -269,7 +276,18 @@ def get_optimizer(parameters, name, lr, weight_decay):
nSteps = math.ceil(args.dataset_size / args.batch_size)
except:
nSteps = 0


print(backbone)
#print(backbone.state_dict())
#embed = nn.Sequential(backbone.embed)
#block_net = nn.Sequential(*list(backbone.blocks))
#backbone = nn.Sequential(embed, block_net)

featuresValidation = generateFeatures(backbone, validationSet)
biyam = testFewShot(featuresValidation, validationSet)
print(biyam)
print(STOP)

criterion = {}
teacher = {}
all_steps = [item for sublist in eval(args.steps) for item in sublist]
@@ -329,6 +347,7 @@ def get_optimizer(parameters, name, lr, weight_decay):
print()
print(" ep. lr ".format(), end='')
for dataset in trainSet:

print(Back.CYAN + " {:>19s} ".format(dataset["name"]) + Style.RESET_ALL, end='')
if epoch >= args.skip_epochs:
for dataset in validationSet:
@@ -358,7 +377,6 @@ def get_optimizer(parameters, name, lr, weight_decay):
else:
raise ValueError(f"Unknown scheduler {args.scheduler}")
lr = lr * args.gamma

continueTest = False
meanVector = None
trainStats = None
@@ -402,9 +420,11 @@ def get_optimizer(parameters, name, lr, weight_decay):
if continueTest:
testStats = tempTestStats
ender = Style.RESET_ALL
if continueTest and args.save_backbone != "" and epoch >= args.skip_epochs:
torch.save(backbone.to("cpu").state_dict(), args.save_backbone)
backbone.to(args.device)
if epoch != 0 and epoch % 10 == 0:
torch.save({"state_dict" : backbone.state_dict()}, f"backbone_{epoch}.pth")
torch.save({"state_dict" : criterion["supervised"][0].state_dict()}, f"class_meta_{epoch}.pth")
#torch.save(backbone.to("cpu").state_dict(), args.save_backbone)
#backbone.to(args.device)
if continueTest and args.save_features_prefix != "" and epoch >= args.skip_epochs:
for i, dataset in enumerate(trainSet):
torch.save(featuresTrain[i], args.save_features_prefix + dataset["name"] + "_features.pt")
@@ -427,16 +447,19 @@ def get_optimizer(parameters, name, lr, weight_decay):
allRunTrainStats = torch.cat([allRunTrainStats, trainStats.unsqueeze(0)])
else:
allRunTrainStats = trainStats.unsqueeze(0)

if validationSet != []:
if allRunValidationStats is not None:
allRunValidationStats = torch.cat([allRunValidationStats, validationStats.unsqueeze(0)])
else:
print(validationStats,"STATS")
allRunValidationStats = validationStats.unsqueeze(0)
if testSet != []:
if allRunTestStats is not None:
allRunTestStats = torch.cat([allRunTestStats, testStats.unsqueeze(0)])
else:
allRunTestStats = testStats.unsqueeze(0)


print()
print("Run " + str(nRun+1) + "/" + str(args.runs) + " finished")
Expand All @@ -453,3 +476,7 @@ def get_optimizer(parameters, name, lr, weight_decay):
print()
if args.wandb!='':
run_wandb.finish()
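
This PR also changes the checkpoint format: weights are written every 10 epochs wrapped in a {"state_dict": ...} dict, and loading unwraps that key with map_location so checkpoints saved on GPU remain readable elsewhere. A minimal round-trip sketch of the new format (file name and Linear module are placeholders):

import torch
import torch.nn as nn

def save_checkpoint(model, path):
    # Matches the f"backbone_{epoch}.pth" files written every 10 epochs.
    torch.save({"state_dict": model.state_dict()}, path)

def load_checkpoint(model, path, device="cpu"):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["state_dict"])
    return model.to(device)

net = nn.Linear(4, 2)
save_checkpoint(net, "backbone_10.pth")
load_checkpoint(nn.Linear(4, 2), "backbone_10.pth")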