diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 0f8b96be..f37b94c4 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -16,3 +16,10 @@
 # according to their microbatching/pipeline parallel setup.
 # Note: this is currently a global flag for simplicity and dynamo performance.
 weight_cache_enabled = False
+
+#
+# Other
+#
+
+# If True, dynamic linear uses hooks for activation casting
+dynamic_use_activation_hooks = True
diff --git a/float8_experimental/dynamic_linear/dynamic_float8_linear.py b/float8_experimental/dynamic_linear/dynamic_float8_linear.py
index f0c6a239..d1548d9a 100644
--- a/float8_experimental/dynamic_linear/dynamic_float8_linear.py
+++ b/float8_experimental/dynamic_linear/dynamic_float8_linear.py
@@ -11,6 +11,7 @@
 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
 
+import float8_experimental.config as config
 
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,24 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    gradY_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    # TODO fix: the next op in the backward does not see this, it sees grad_output[0]
+    return (gradY_fp8,)
+
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -48,9 +67,16 @@
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +85,10 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
@@ -69,6 +98,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +122,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_pre_hook)
         return new_mod
diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index 392358f2..51800226 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -77,6 +77,7 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
 
 @implements([aten.mm.default])
 def float8_mm(aten_op, args, kwargs=None):
+    print('float8_mm', args)
     assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
     a = args[0]
     b = args[1]
diff --git a/test/test_bw_hook.py b/test/test_bw_hook.py
new file mode 100644
index 00000000..40146cc9
--- /dev/null
+++ b/test/test_bw_hook.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+
+class TestAutogradFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, tensor):
+        tensor = tensor + 1.0
+        return tensor
+
+    @staticmethod
+    def backward(ctx, gradY):
+        # prints a tensor filled with 0.123, as expected
+        print('gradY', gradY)
+        gradY = gradY + 1.0
+        return gradY
+
+class M(nn.Module):
+    def forward(self, x):
+        return TestAutogradFunction.apply(x)
+
+m = M()
+
+def bw_pre_hook(module, go):
+    new_go = torch.empty_like(go[0]).fill_(0.123)
+    return (new_go,)
+
+m.register_full_backward_pre_hook(bw_pre_hook)
+
+x = torch.randn(2, 2).requires_grad_()
+y = m(x)
+y.sum().backward()
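
A minimal usage sketch of the hook-based casting path added above; this is not part of the diff, and the import path plus the emulate=True choice are assumptions based on the files touched here:

import torch
import torch.nn as nn

import float8_experimental.config as config
from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear

# Opt into the pre-hook casting path before converting; from_float reads this flag.
config.dynamic_use_activation_hooks = True

# Convert a plain linear; emulate=True is assumed here so no native fp8 hardware is needed.
m_fp8 = Float8DynamicLinear.from_float(nn.Linear(16, 32), emulate=True)

x = torch.randn(4, 16, requires_grad=True)
y = m_fp8(x)        # forward pre-hook casts x to float8_e4m3fn
y.sum().backward()  # full backward pre-hook attempts the dL/dY cast to float8_e5m2 (see TODO above)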