
feat: support distillation strategy
The-truthh committed Aug 8, 2023
1 parent 85450da commit 5a2f0a7
Showing 5 changed files with 134 additions and 2 deletions.
14 changes: 14 additions & 0 deletions config.py
@@ -256,6 +256,20 @@ def create_parser():
    group.add_argument('--drop_overflow_update', type=bool, default=False,
                       help='Whether to execute optimizer if there is an overflow (default=False)')

    # distillation
    group = parser.add_argument_group('Distillation parameters')
    group.add_argument('--distillation_type', type=str, default=None,
                       choices=['hard', 'soft'],
                       help='The type of distillation (default=None)')
    group.add_argument('--teacher_model', type=str, default=None,
                       help='Name of the teacher model (default=None)')
    group.add_argument('--teacher_ckpt_path', type=str, default='',
                       help='Initialize the teacher model from this checkpoint. '
                            'If resuming training, specify the checkpoint path (default="").')
    group.add_argument('--distillation_alpha', type=float, default=0.5,
                       help='The coefficient balancing the distillation loss and the base loss '
                            '(default=0.5)')

    # modelarts
    group = parser.add_argument_group('modelarts')
    group.add_argument('--enable_modelarts', type=str2bool, nargs='?', const=True, default=False,
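
Note (illustration, not part of the commit): a minimal, self-contained sketch of how the new options above could be exercised. The option names mirror the diff; the model name "convnext_tiny" and the checkpoint path are placeholder values.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('Distillation parameters')
group.add_argument('--distillation_type', type=str, default=None, choices=['hard', 'soft'])
group.add_argument('--teacher_model', type=str, default=None)
group.add_argument('--teacher_ckpt_path', type=str, default='')
group.add_argument('--distillation_alpha', type=float, default=0.5)

# Parse an example command line: soft distillation from a pretrained teacher.
args = parser.parse_args([
    '--distillation_type', 'soft',
    '--teacher_model', 'convnext_tiny',
    '--teacher_ckpt_path', './ckpt/teacher.ckpt',
])
print(args.distillation_type, args.distillation_alpha)  # soft 0.5
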
1 change: 1 addition & 0 deletions mindcv/utils/__init__.py
@@ -2,6 +2,7 @@
from .amp import *
from .callbacks import *
from .checkpoint_manager import *
from .distillation import *
from .download import *
from .logger import *
from .path import *
87 changes: 87 additions & 0 deletions mindcv/utils/distillation.py
@@ -0,0 +1,87 @@
""" distillation loss cell define """
from types import MethodType

import mindspore as ms
from mindspore import nn
from mindspore.ops import functional as F


class DistillLossCell(nn.WithLossCell):
    """
    Wraps the student network with a knowledge distillation loss.

    Computes the base loss of the student network plus an extra knowledge distillation
    loss that uses the teacher model's predictions as additional supervision.

    Args:
        backbone (Cell): The student network to train and calculate the base loss.
        loss_fn (Cell): The loss function used to compute the loss of the student network.
        distillation_type (str): The type of distillation, either "hard" or "soft".
        teacher_model (Cell): The teacher network used to calculate the distillation loss.
        alpha (float): The coefficient balancing the distillation loss and the base loss. Default: 0.5.
        tau (float): Distillation temperature. A higher temperature yields softer probability
            distributions and hence a smoother Kullback-Leibler divergence loss. Default: 1.0.
    """

    def __init__(self, backbone, loss_fn, distillation_type, teacher_model, alpha=0.5, tau=1.0):
        super().__init__(backbone, loss_fn)
        if distillation_type == "hard":
            self.hard_type = True
        elif distillation_type == "soft":
            self.hard_type = False
        else:
            raise ValueError(f"Distillation type only supports ['hard', 'soft'], but got {distillation_type}.")
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.tau = tau

    def construct(self, data, label):
        # The student backbone returns two outputs: `out` for the base loss
        # and `out_kd` for the distillation loss.
        out = self._backbone(data)

        out, out_kd = out
        base_loss = self._loss_fn(out, label)

        # The teacher is used as a fixed target: no gradients flow through it.
        teacher_out = F.stop_gradient(self.teacher_model(data))

        if self.hard_type:
            # Hard distillation: cross-entropy against the teacher's argmax prediction.
            distillation_loss = F.cross_entropy(out_kd, teacher_out.argmax(axis=1))
        else:
            # Soft distillation: temperature-scaled KL divergence, scaled by T^2 / numel.
            T = self.tau
            out_kd = F.cast(out_kd, ms.float32)
            distillation_loss = (
                F.kl_div(
                    F.log_softmax(out_kd / T, axis=1),
                    F.log_softmax(teacher_out / T, axis=1),
                    reduction="sum",
                )
                * (T * T)
                / F.size(out_kd)
            )

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha

        return loss


def set_validation(network):
    r"""
    MindSpore cannot automatically switch some cells of the teacher network to
    validation mode during training, so this function sets those cells to
    validation mode manually.
    """

    for _, cell in network.cells_and_names():
        if isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d, nn.BatchNorm3d)):
            cell.construct = MethodType(bn_infer_only, cell)
        elif isinstance(cell, nn.Dropout):
            cell.construct = MethodType(dropout_infer_only, cell)
        else:
            cell.set_train(False)


def bn_infer_only(self, x):
    return self.bn_infer(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0]


def dropout_infer_only(self, x):
    return x
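
Note (illustration, not part of the commit): a framework-free NumPy sketch of the DeiT-style soft-distillation objective that DistillLossCell is modelled on, using toy logits and the default alpha/tau; in the hard case the extra term is simply cross-entropy against the teacher's argmax prediction. The function name and all values here are made up for illustration.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def soft_distillation_loss(student_kd_logits, teacher_logits, base_loss, alpha=0.5, tau=1.0):
    T = tau
    p_teacher = softmax(teacher_logits / T)                 # soft targets from the teacher
    log_p_student = np.log(softmax(student_kd_logits / T))  # student log-probabilities
    # KL divergence summed over the batch, scaled by T^2 and normalized by the
    # element count, mirroring the T * T / F.size(out_kd) scaling above.
    kl = np.sum(p_teacher * (np.log(p_teacher) - log_p_student))
    distillation = kl * (T * T) / student_kd_logits.size
    return base_loss * (1 - alpha) + distillation * alpha

student = np.array([[2.0, 0.5, -1.0]])
teacher = np.array([[1.5, 0.8, -0.5]])
print(soft_distillation_loss(student, teacher, base_loss=0.7))
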
17 changes: 15 additions & 2 deletions mindcv/utils/trainer_factory.py
@@ -9,6 +9,7 @@
from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model

from .amp import auto_mixed_precision
from .distillation import DistillLossCell
from .train_step import TrainStep

__all__ = [
@@ -38,6 +39,7 @@ def require_customized_train_step(
    clip_grad: bool = False,
    gradient_accumulation_steps: int = 1,
    amp_cast_list: Optional[str] = None,
    distillation_type: Optional[str] = None,
):
    if ema:
        return True
@@ -47,6 +49,8 @@
        return True
    if amp_cast_list:
        return True
    if distillation_type:
        return True
    return False


@@ -88,6 +92,9 @@ def create_trainer(
    clip_grad: bool = False,
    clip_value: float = 15.0,
    gradient_accumulation_steps: int = 1,
    distillation_type: Optional[str] = None,
    teacher_network: Optional[nn.Cell] = None,
    distillation_alpha: float = 0.5,
):
    """Create Trainer.
@@ -106,6 +113,9 @@
        clip_grad: Whether to clip gradients.
        clip_value: The value at which to clip gradients.
        gradient_accumulation_steps: Accumulate the gradients of n batches before updating.
        distillation_type: The type of distillation ("hard" or "soft").
        teacher_network: The teacher network used for distillation.
        distillation_alpha: The coefficient balancing the distillation loss and the base loss.

    Returns:
        mindspore.Model
@@ -120,7 +130,7 @@
    if gradient_accumulation_steps < 1:
        raise ValueError("`gradient_accumulation_steps` must be >= 1!")

    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list):
    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list, distillation_type):
        mindspore_kwargs = dict(
            network=network,
            loss_fn=loss,
@@ -149,7 +159,10 @@
    else:  # require customized train step
        eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
        auto_mixed_precision(network, amp_level, amp_cast_list)
        net_with_loss = add_loss_network(network, loss, amp_level)
        if distillation_type:
            net_with_loss = DistillLossCell(network, loss, distillation_type, teacher_network, distillation_alpha)
        else:
            net_with_loss = add_loss_network(network, loss, amp_level)
        train_step_kwargs = dict(
            network=net_with_loss,
            optimizer=optimizer,
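
Note (illustration, not part of the commit): a condensed, framework-free sketch of the selection logic above. Any non-empty distillation_type forces the customized train step, where the loss wrapper becomes DistillLossCell instead of the plain loss cell.

from typing import Optional

def needs_customized_train_step(ema: bool = False,
                                clip_grad: bool = False,
                                gradient_accumulation_steps: int = 1,
                                amp_cast_list: Optional[str] = None,
                                distillation_type: Optional[str] = None) -> bool:
    # Condensed version of require_customized_train_step in the diff above.
    return bool(ema or clip_grad or gradient_accumulation_steps > 1
                or amp_cast_list or distillation_type)

print(needs_customized_train_step())                          # False -> plain mindspore.Model path
print(needs_customized_train_step(distillation_type="soft"))  # True  -> customized TrainStep with DistillLossCell
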
17 changes: 17 additions & 0 deletions train.py
@@ -19,6 +19,7 @@
    require_customized_train_step,
    set_logger,
    set_seed,
    set_validation,
)

from config import parse_args, save_args # isort: skip
@@ -180,6 +181,18 @@ def train(args):
        aux_factor=args.aux_factor,
    )

    # create teacher model
    teacher_network = None
    if args.distillation_type:
        if not args.teacher_ckpt_path:
            logger.warning("You are using distillation, but no checkpoint was provided for the teacher model.")
        teacher_network = create_model(
            model_name=args.teacher_model,
            num_classes=num_classes,
            checkpoint_path=args.teacher_ckpt_path,
        )
        set_validation(teacher_network)

    # create learning rate schedule
    lr_scheduler = create_scheduler(
        num_batches,
@@ -213,6 +226,7 @@
            args.clip_grad,
            args.gradient_accumulation_steps,
            args.amp_cast_list,
            args.distillation_type,
        )
    ):
        optimizer_loss_scale = args.loss_scale
@@ -250,6 +264,9 @@
        clip_grad=args.clip_grad,
        clip_value=args.clip_value,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        distillation_type=args.distillation_type,
        teacher_network=teacher_network,
        distillation_alpha=args.distillation_alpha,
    )

    # callback
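
Note (illustration, not part of the commit): the set_validation(teacher_network) call above relies on rebinding construct per instance via types.MethodType. A toy, framework-free sketch of that mechanism:

from types import MethodType

class ToyDropout:
    def construct(self, x):
        return x * 0.5  # pretend training-time behaviour

def infer_only(self, x):
    return x  # identity at inference time

cell = ToyDropout()
print(cell.construct(4.0))  # 2.0 -> original training-time behaviour
cell.construct = MethodType(infer_only, cell)  # rebind on this instance only
print(cell.construct(4.0))  # 4.0 -> patched inference behaviour
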
