
Commit

[wip] hooks
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Dec 26, 2023
1 parent 31fba04 commit 3fe1055
Showing 4 changed files with 78 additions and 2 deletions.
7 changes: 7 additions & 0 deletions float8_experimental/config.py
@@ -16,3 +16,10 @@
# according to their microbatching/pipeline parallel setup.
# Note: this is currently a global flag for simplicity and dynamo performance.
weight_cache_enabled = False

#
# Other
#

# If True, dynamic linear uses hooks for activation casting
dynamic_use_activation_hooks = True
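
As a hedged sketch (not part of this commit): the flag is read when Float8DynamicLinear modules are constructed or converted, so a caller would set it before building the model, for example:

# Sketch only, assuming the package layout above; not part of the diff.
import float8_experimental.config as config

config.dynamic_use_activation_hooks = True  # cast activations/gradients via pre-hooks
# ... construct or convert Float8DynamicLinear modules after this point ...
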
40 changes: 38 additions & 2 deletions float8_experimental/dynamic_linear/dynamic_float8_linear.py
@@ -11,6 +11,7 @@

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
import float8_experimental.config as config


class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,24 @@ def backward(ctx, gradY):
            None,
        )

def cast_x_to_float8_e4m3fn_pre_hook(module, args):
    """
    Hook to cast the incoming activation to `torch.float8_e4m3fn`
    """
    return module.cast_to_float8(args[0])

def cast_dldy_to_float8_e5m2_pre_hook(module, grad_output):
    """
    Hook to cast the incoming gradient to `torch.float8_e5m2`
    """
    gradY = grad_output[0]
    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
    gradY_scaled = gradY * gradY_scale
    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
    gradY_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
    # TODO fix: the next op in the backward does not see this, it sees grad_output[0]
    return (gradY_fp8,)


class Float8DynamicLinear(torch.nn.Linear):
"""
@@ -48,9 +67,16 @@ class Float8DynamicLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_weight_tag()
        self.use_activation_hooks = config.dynamic_use_activation_hooks

    def forward(self, x):
        x_fp8 = self.cast_to_float8(x)
        # cast x to float8_e4m3fn
        if self.use_activation_hooks:
            x_fp8 = x
        else:
            x_fp8 = self.cast_to_float8(x)

        # cast w to float8_e4m3fn
        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
            w_fp8 = self._w_fp8
        else:
@@ -59,7 +85,10 @@ def forward(self, x):
        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

        # Cast gradY to float8_e5m2 during backward
        y = self.cast_to_float8e5m2_bw(y)
        if self.use_activation_hooks:
            pass
        else:
            y = self.cast_to_float8e5m2_bw(y)

        return y

@@ -69,6 +98,7 @@ def add_weight_tag(self):
        self.weight._is_fp8_weight = True

    def cast_to_float8(self, inpt_tensor):
        # TODO rename this function to clarify e4m3
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
        return Float8Tensor.to_float8(
            inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +122,10 @@ def from_float(cls, mod, emulate: bool = False):
        new_mod.bias = mod.bias
        new_mod.emulate = emulate
        new_mod.add_weight_tag()

        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
        if new_mod.use_activation_hooks:
            # install the hooks
            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_pre_hook)
        return new_mod
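
A hedged end-to-end sketch (not in this commit; the import path mirrors the file path above, and emulate=True plus the classmethod call are assumptions): converting a plain nn.Linear via from_float installs both pre-hooks, so the forward pre-hook casts x to float8_e4m3fn and the full backward pre-hook attempts the float8_e5m2 cast of dL/dY flagged in the TODO above.

# Hypothetical usage sketch; not part of this commit.
import torch

import float8_experimental.config as config
from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear

config.dynamic_use_activation_hooks = True
lin = Float8DynamicLinear.from_float(torch.nn.Linear(8, 8), emulate=True)

x = torch.randn(4, 8, requires_grad=True)
y = lin(x)          # forward pre-hook casts x to a Float8Tensor (e4m3fn)
y.sum().backward()  # backward pre-hook fires; per the TODO above, the cast
                    # gradient is not yet what the preceding ops observe
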
1 change: 1 addition & 0 deletions float8_experimental/float8_ops.py
@@ -77,6 +77,7 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):

@implements([aten.mm.default])
def float8_mm(aten_op, args, kwargs=None):
    print('float8_mm', args)
    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
    a = args[0]
    b = args[1]
32 changes: 32 additions & 0 deletions test/test_bw_hook.py
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn

class TestAutogradFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor):
        tensor = tensor + 1.0
        return tensor

    @staticmethod
    def backward(ctx, gradY):
        # prints a tensor filled with 0.123, as expected
        print('gradY', gradY)
        gradY = gradY + 1.0
        return gradY

class M(nn.Module):
    def forward(self, x):
        return TestAutogradFunction.apply(x)

m = M()

def bw_pre_hook(module, go):
    new_go = torch.empty_like(go[0]).fill_(0.123)
    return (new_go,)

m.register_full_backward_pre_hook(bw_pre_hook)

x = torch.randn(2, 2).requires_grad_()
y = m(x)
y.sum().backward()
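
A possible follow-up check (not in the commit), assuming register_full_backward_pre_hook (PyTorch 2.0+) replaces grad_output with the hook's return value as documented: the custom backward adds 1.0 to the swapped-in 0.123 gradient, so the leaf gradient should be 1.123 everywhere.

# Hypothetical assertion, not part of the commit: the pre-hook replaces the
# incoming gradient with 0.123 and TestAutogradFunction.backward adds 1.0.
assert torch.allclose(x.grad, torch.full_like(x, 1.123))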
