Skip to content

Commit

Permalink
Fix AMP logic + disable gradscaler for bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 11, 2024
1 parent 52a67f9 commit c7198ad
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 75 deletions.
2 changes: 2 additions & 0 deletions DOCUMENTATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ factor_args = FactorArguments(
strategy="ekfac", # Choose from "identity", "diagonal", "kfac", or "ekfac".
use_empirical_fisher=False,
amp_dtype=None,
amp_scale=2.0**16,
has_shared_parameters=False,

# Settings for covariance matrix fitting.
Expand Down Expand Up @@ -236,6 +237,7 @@ You can change:
- `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using the actual labels from the batch)
instead of the true Fisher (using labels sampled from the model's predictions). It is recommended to keep this set to `False`.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
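- `amp_scale`: Sets the initial scale factor of the gradient scaler used with [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Only applies when `amp_dtype=torch.float16`; the gradient scaler is disabled for other dtypes such as `torch.bfloat16`.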
- `has_shared_parameters`: Specifies whether shared parameters are present in the model's forward pass.

### Fitting Covariance Matrices
Expand Down
4 changes: 4 additions & 0 deletions kronfluence/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class FactorArguments(Arguments):
default=None,
metadata={"help": "Data type for automatic mixed precision (AMP). If `None`, AMP is disabled."},
)
amp_scale: float = field(
default=2.0**16,
metadata={"help": "Scale factor for AMP (only applicable when `amp_dtype=torch.float16`)."},
)
has_shared_parameters: bool = field(
default=False,
metadata={"help": "Indicates whether shared parameters are present in the model's forward pass."},
Expand Down
7 changes: 4 additions & 3 deletions kronfluence/factor/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ def fit_covariance_matrices_with_loader(
total_steps = 0
num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
enable_amp = factor_args.amp_dtype is not None
scaler = GradScaler(enabled=enable_amp)
if enable_amp:
enable_grad_scaler = enable_amp and factor_args.amp_dtype == torch.float16
scaler = GradScaler(init_scale=factor_args.amp_scale, enabled=enable_grad_scaler)
if enable_grad_scaler:
gradient_scale = 1.0 / scaler.get_scale()
set_gradient_scale(model=model, gradient_scale=gradient_scale)

Expand Down Expand Up @@ -257,7 +258,7 @@ def fit_covariance_matrices_with_loader(

model.zero_grad(set_to_none=True)
set_attention_mask(model=model, attention_mask=None)
if enable_amp:
if enable_grad_scaler:
set_gradient_scale(model=model, gradient_scale=1.0)
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
state.wait_for_everyone()
Expand Down
7 changes: 4 additions & 3 deletions kronfluence/factor/eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,9 @@ def fit_lambda_matrices_with_loader(
total_steps = 0
num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
enable_amp = factor_args.amp_dtype is not None
scaler = GradScaler(enabled=enable_amp)
if enable_amp:
enable_grad_scaler = enable_amp and factor_args.amp_dtype == torch.float16
scaler = GradScaler(init_scale=factor_args.amp_scale, enabled=enable_grad_scaler)
if enable_grad_scaler:
gradient_scale = 1.0 / scaler.get_scale()
set_gradient_scale(model=model, gradient_scale=gradient_scale)

Expand Down Expand Up @@ -453,7 +454,7 @@ def fit_lambda_matrices_with_loader(
saved_factors[factor_name] = factor

model.zero_grad(set_to_none=True)
if enable_amp:
if enable_grad_scaler:
set_gradient_scale(model=model, gradient_scale=1.0)
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
state.wait_for_everyone()
Expand Down
22 changes: 0 additions & 22 deletions kronfluence/module/tracker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,6 @@ def _raise_cache_not_found_exception(self) -> None:
f"For case 2, set 'has_shared_parameters=True' to enable parameter sharing."
)

def _preprocess_gradient(self, output_gradient: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
"""Preprocesses the output gradient.
Args:
output_gradient (torch.Tensor):
The original output gradient.
target_dtype (torch.dtype):
The desired data type for the gradient tensor.
Returns:
torch.Tensor:
The preprocessed gradient.
"""
original_dtype = output_gradient.dtype
output_gradient = output_gradient.to(dtype=target_dtype)
if self.module.gradient_scale != 1.0:
if original_dtype != target_dtype:
output_gradient.mul_(self.module.gradient_scale)
else:
output_gradient = output_gradient * self.module.gradient_scale
return output_gradient

def register_hooks(self) -> None:
"""Registers hooks for the module."""

Expand Down
21 changes: 11 additions & 10 deletions kronfluence/module/tracker/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def _update_gradient_covariance_matrix(
)
self._gradient_covariance_initialized = True
self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count)
self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(output_gradient.t(), output_gradient)
alpha = 1
if self.module.gradient_scale != 1.0:
alpha = self.module.gradient_scale**2.0
self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(output_gradient.t(), output_gradient, alpha=alpha)

def register_hooks(self) -> None:
"""Sets up hooks to compute activation and gradient covariance matrices."""
Expand All @@ -112,9 +115,7 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.
def backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient.detach(), target_dtype=self.module.factor_args.gradient_covariance_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.factor_args.gradient_covariance_dtype)
# Computes and updates pseudo-gradient covariance during backward pass.
output_gradient, count = self.module.get_flattened_gradient(output_gradient=output_gradient)
self._update_gradient_covariance_matrix(output_gradient=output_gradient, count=count)
Expand Down Expand Up @@ -259,25 +260,23 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._raise_cache_not_found_exception()
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.factor_args.per_sample_gradient_dtype)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=self.cached_activations.to(device=output_gradient.device),
output_gradient=output_gradient,
).to(dtype=self.module.factor_args.lambda_dtype)
self.clear_all_cache()
del output_gradient
if self.module.gradient_scale != 1.0:
per_sample_gradient.mul_(self.module.gradient_scale)
# Computes and updates lambda matrix during backward pass.
self._update_lambda_matrix(per_sample_gradient=per_sample_gradient)

@torch.no_grad()
def shared_backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.factor_args.per_sample_gradient_dtype)
cached_activation = self.cached_activations.pop()
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=cached_activation.to(device=output_gradient.device),
Expand All @@ -297,6 +296,8 @@ def finalize_iteration(self) -> None:
self.cached_per_sample_gradient = self.cached_per_sample_gradient.to(
dtype=self.module.factor_args.lambda_dtype
)
if self.module.gradient_scale != 1.0:
self.cached_per_sample_gradient.mul_(self.module.gradient_scale)
self._update_lambda_matrix(per_sample_gradient=self.cached_per_sample_gradient)
self.clear_all_cache()

Expand Down
6 changes: 3 additions & 3 deletions kronfluence/module/tracker/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._raise_cache_not_found_exception()
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
if isinstance(self.cached_activations, list):
cached_activation = self.cached_activations.pop()
else:
Expand All @@ -56,6 +54,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
input_activation=cached_activation.to(device=output_gradient.device),
output_gradient=output_gradient,
).sum(dim=0, keepdim=True)
if self.module.gradient_scale != 1.0:
summed_gradient.mul_(self.module.gradient_scale)
if self.module.storage[AGGREGATED_GRADIENT_NAME] is None:
self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(summed_gradient, requires_grad=False)
self.module.storage[AGGREGATED_GRADIENT_NAME].add_(summed_gradient)
Expand Down
8 changes: 5 additions & 3 deletions kronfluence/module/tracker/pairwise_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._raise_cache_not_found_exception()
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient.detach(), target_dtype=self.module.score_args.score_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.score_args.score_dtype)
if isinstance(self.cached_activations, list):
cached_activation = self.cached_activations.pop()
else:
Expand All @@ -90,6 +88,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
input_activation=cached_activation.to(device=output_gradient.device),
output_gradient=output_gradient,
)
if self.module.gradient_scale != 1.0:
self.module.storage[PAIRWISE_SCORE_MATRIX_NAME].mul_(self.module.gradient_scale)
del cached_activation, output_gradient
self.clear_all_cache()
else:
Expand All @@ -98,6 +98,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
output_gradient=output_gradient,
)
del cached_activation, output_gradient
if self.module.gradient_scale != 1.0:
per_sample_gradient.mul_(self.module.gradient_scale)
self._compute_pairwise_score_with_gradient(per_sample_gradient=per_sample_gradient)

self.registered_hooks.append(self.module.register_forward_hook(forward_hook))
Expand Down
12 changes: 6 additions & 6 deletions kronfluence/module/tracker/precondition.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._raise_cache_not_found_exception()
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=self.cached_activations.to(device=output_gradient.device),
output_gradient=output_gradient,
Expand All @@ -119,16 +117,16 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
gradient=per_sample_gradient,
storage=self.module.storage,
)
if self.module.gradient_scale != 1.0:
preconditioned_gradient.mul_(self.module.gradient_scale)
del per_sample_gradient
self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient)

@torch.no_grad()
def shared_backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
cached_activation = self.cached_activations.pop()
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=cached_activation.to(device=output_gradient.device),
Expand All @@ -153,6 +151,8 @@ def finalize_iteration(self) -> None:
storage=self.module.storage,
)
self.cached_per_sample_gradient = None
if self.module.gradient_scale != 1.0:
preconditioned_gradient.mul_(self.module.gradient_scale)
self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient)
self.clear_all_cache()

Expand Down
20 changes: 11 additions & 9 deletions kronfluence/module/tracker/self_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,24 +91,22 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._raise_cache_not_found_exception()
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=self.cached_activations.to(device=output_gradient.device),
output_gradient=output_gradient,
).to(dtype=self.module.score_args.precondition_dtype)
self.clear_all_cache()
del output_gradient
if self.module.gradient_scale != 1.0:
per_sample_gradient.mul_(self.module.gradient_scale)
self._compute_self_score(per_sample_gradient=per_sample_gradient)

@torch.no_grad()
def shared_backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.score_args.per_sample_gradient_dtype)
cached_activation = self.cached_activations.pop()
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=cached_activation.to(device=output_gradient.device),
Expand All @@ -127,6 +125,8 @@ def finalize_iteration(self) -> None:
self.cached_per_sample_gradient = self.cached_per_sample_gradient.to(
dtype=self.module.score_args.precondition_dtype
)
if self.module.gradient_scale != 1.0:
self.cached_per_sample_gradient.mul_(self.module.gradient_scale)
self._compute_self_score(per_sample_gradient=self.cached_per_sample_gradient)
self.clear_all_cache()

Expand Down Expand Up @@ -202,9 +202,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:

handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient.detach(), target_dtype=self.module.score_args.score_dtype
)
output_gradient = output_gradient.detach().to(dtype=self.module.score_args.score_dtype)
if isinstance(self.cached_activations, list):
cached_activation = self.cached_activations.pop()
else:
Expand All @@ -217,6 +215,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
)
self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None
self.clear_all_cache()
if self.module.gradient_scale != 1.0:
scores.mul_(self.module.gradient_scale)
if self.module.storage[SELF_SCORE_VECTOR_NAME] is None:
self.module.storage[SELF_SCORE_VECTOR_NAME] = scores
else:
Expand All @@ -227,6 +227,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
output_gradient=output_gradient,
)
del cached_activation, output_gradient
if self.module.gradient_scale != 1.0:
per_sample_gradient.mul_(self.module.gradient_scale)
self._compute_self_measurement_score_with_gradient(per_sample_gradient=per_sample_gradient)

self.registered_hooks.append(self.module.register_forward_hook(forward_hook))
Expand Down
14 changes: 8 additions & 6 deletions kronfluence/score/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,9 @@ def compute_pairwise_scores_with_loaders(
query_iter = iter(query_loader)
num_accumulations = 0
enable_amp = score_args.amp_dtype is not None
scaler = GradScaler(enabled=enable_amp)
if enable_amp:
enable_grad_scaler = enable_amp and factor_args.amp_dtype == torch.float16
scaler = GradScaler(init_scale=factor_args.amp_scale, enabled=enable_grad_scaler)
if enable_grad_scaler:
gradient_scale = 1.0 / scaler.get_scale()
set_gradient_scale(model=model, gradient_scale=gradient_scale)

Expand Down Expand Up @@ -283,7 +284,7 @@ def compute_pairwise_scores_with_loaders(
total_scores_chunks[module_name] = torch.cat(total_scores_chunks[module_name], dim=0)

model.zero_grad(set_to_none=True)
if enable_amp:
if enable_grad_scaler:
set_gradient_scale(model=model, gradient_scale=1.0)
finalize_all_iterations(model=model, tracked_module_names=tracked_module_names)
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
Expand Down Expand Up @@ -324,8 +325,9 @@ def compute_pairwise_query_aggregated_scores_with_loaders(
prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device)

enable_amp = score_args.amp_dtype is not None
scaler = GradScaler(enabled=enable_amp)
if enable_amp:
enable_grad_scaler = enable_amp and factor_args.amp_dtype == torch.float16
scaler = GradScaler(init_scale=factor_args.amp_scale, enabled=enable_grad_scaler)
if enable_grad_scaler:
gradient_scale = 1.0 / scaler.get_scale()
set_gradient_scale(model=model, gradient_scale=gradient_scale)

Expand Down Expand Up @@ -383,7 +385,7 @@ def compute_pairwise_query_aggregated_scores_with_loaders(
)

model.zero_grad(set_to_none=True)
if enable_amp:
if enable_grad_scaler:
set_gradient_scale(model=model, gradient_scale=1.0)
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
state.wait_for_everyone()
Expand Down
Loading

0 comments on commit c7198ad

Please sign in to comment.