Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: port to torch 1.13 #24

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions center_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import warnings

import torch
import torch.nn as nn
from torch import nn
from torch.nn import functional as F


class CenterLoss(nn.Module):
"""Center loss.
Expand All @@ -11,34 +15,22 @@ class CenterLoss(nn.Module):
num_classes (int): number of classes.
feat_dim (int): feature dimension.
"""
def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
def __init__(self, num_classes: int = 10, feat_dim: int = 2, use_gpu: bool = None, clamp: int = 1e-12):
super(CenterLoss, self).__init__()
if use_gpu is not None:
warnings.warning(f"Ignoring explicitly set {use_gpu=}. Move the model via .to(device)")
self.num_classes = num_classes
self.feat_dim = feat_dim
self.use_gpu = use_gpu
self.clamp = clamp

if self.use_gpu:
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
else:
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

def forward(self, x, labels):
"""
Args:
x: feature matrix with shape (batch_size, feat_dim).
labels: ground truth labels with shape (batch_size).
"""
batch_size = x.size(0)
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
distmat.addmm_(1, -2, x, self.centers.t())

classes = torch.arange(self.num_classes).long()
if self.use_gpu: classes = classes.cuda()
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
mask = labels.eq(classes.expand(batch_size, self.num_classes))

dist = distmat * mask.float()
loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

return loss
centers = torch.index_select(self.centers, 0, labels.view(-1)) # [Classes, Features] (gather) [Batch] -> [Batch, Features]
return F.mse_loss(x, centers) * self.feat_dim # mean across all axes except features