Updates to support u-muP, as the new default behaviour #58
Conversation
@lyprince, thank you for agreeing to review! @thecharlieblake FYI: this is still WIP, but (I think/hope) the major changes are in.
)
self.is_causal = is_causal
self.mult = mult
self.linear_qkv = Linear(hidden_size, 3 * hidden_size, constraint=constraint)
Without a `linear` specialisation (like `linear_readout`), this will have an extra sqrt(3) factor in the scale. Accordingly, our default implementation does not fuse the qkv matmul.
Ah, yes, I wondered about this; the cheeky thing I was thinking is that when using the new-default constraint "to_output_scale" it shouldn't matter, as the scale just depends on fan_in. But maybe this is a bit cheeky as it can be overridden.
> the scale just depends on fan_in

Won't the output scale of this op be `3 * fan_in` rather than `fan_in`, though? I see in the demo notebook that `attn_qkv.weight.grad.std = 0.62`, which is pretty close to `1/sqrt(3) = 0.58`.
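To make the fused-qkv scale question concrete, here is a minimal plain-PyTorch sketch (the scaling factor is written out by hand rather than using `uu.Linear`, and the sizes are illustrative): with a scale based only on `fan_in`, the forward pass stays unit-scaled but the backward pass picks up a sqrt(3) factor from the 3x wider output.

```python
import torch

hidden, batch = 1024, 4096  # illustrative sizes

# Unit-normal input and an unscaled, unit-normal weight, as in unit scaling.
x = torch.randn(batch, hidden, requires_grad=True)
w_fused = torch.randn(3 * hidden, hidden)  # fused qkv: fan_out = 3 * hidden

# "to_output_scale"-style rule: share a single 1/sqrt(fan_in) scale.
y = (x @ w_fused.t()) * hidden ** -0.5
print(y.std())  # ~1.0: the forward output is unit-scaled

# Backward with a unit-normal incoming gradient.
y.backward(torch.randn_like(y))
print(x.grad.std())  # ~sqrt(3) ≈ 1.73: the backward fan is 3 * hidden, not hidden
```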
Looks good! I've done my best to check for correctness. I have some uncertainty around residual scaling; a test that unit scale is preserved across depth under the residual scaling scheme would give me more confidence (a rough sketch of such a test is below the list).

Other than that, my main comments are around:

- Documentation that points to the relevant parts of the paper justifying design choices / changes.
- Doubts over whether `nn.Sequential` is the right choice for keeping track of depth, given common patterns for implementing masks and positional embeddings.
- Clarification on whether it is intentional to fuse the qkv projection in MHSA.
- Testing compatibility of the Parameter class with `torch.compile` (does FakeTensor trigger a `deepcopy` failure?).
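For the residual-scaling point, this is roughly the kind of test I have in mind. It uses a hand-rolled combine whose squared coefficients sum to one as a stand-in for the library's `residual_split` / `residual_add` with `tau`, and a scaled matmul as a stand-in for a unit-scaled layer; the real test should use the `uu.*` ops directly.

```python
import torch

def test_unit_scale_preserved_across_depth() -> None:
    torch.manual_seed(0)
    depth, width, batch, tau = 64, 1024, 256, 0.5

    x = torch.randn(batch, width)
    for _ in range(depth):
        # Stand-in for a unit-scaled residual branch (e.g. a transformer layer):
        # a plain matmul with 1/sqrt(fan_in) scaling keeps the branch near unit scale.
        w = torch.randn(width, width)
        branch = (x @ w.t()) * width ** -0.5
        # Combine skip and branch with coefficients whose squares sum to one, so the
        # mix stays near unit scale when the two are roughly independent.
        x = (1 - tau) ** 0.5 * x + tau ** 0.5 * branch

    assert abs(x.std().item() - 1.0) < 0.1
```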
examples/demo.ipynb
"# Config & helpers\n", | ||
"torch.backends.cuda.matmul.allow_tf32 = True\n", | ||
"torch.backends.cudnn.allow_tf32 = True\n", | ||
"def show_layer_stats(layer: nn.Module, input_shape: Tuple[int, ...]) -> None:\n", |
Demo notebook looks good! The only thing that broke my attention was searching for `show_layer_stats`; it was easy to miss when colocated with the imports. My preference would be to have it defined next to its first use, or at least in its own cell.
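For context, `show_layer_stats` only appears by signature in the diff above; an illustrative stand-in for what such a helper typically does (not the notebook's actual implementation) is:

```python
from typing import Tuple

import torch
from torch import nn

def show_layer_stats(layer: nn.Module, input_shape: Tuple[int, ...]) -> None:
    """Print forward/backward scale statistics for a layer fed unit-normal input."""
    x = torch.randn(*input_shape, requires_grad=True)
    y = layer(x)
    y.backward(torch.randn_like(y))
    print(f"output.std     = {y.std().item():.3f}")
    print(f"input.grad.std = {x.grad.std().item():.3f}")
    for name, param in layer.named_parameters():
        print(f"{name}.std = {param.std().item():.3f}")
        if param.grad is not None:
            print(f"{name}.grad.std = {param.grad.std().item():.3f}")
```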
- Demo notebook.
- README header.

Changes:

- `U.*` ops and `uu.*` modules:
  - `U.readout_linear`, `uu.ReadoutLinear` to apply 1/fan_in scaling
  - `U.rms_norm`, `uu.RMSNorm`, `U.silu`, `uu.SiLU`, `U.silu_glu`, `U.mse_loss`
  - `uu.Parameter`, `uu.optim.*` to provide LR scaling
  - `uu.Trunk`, `uu.TransformerStack` to set `mup_scaling_depth` and (for `TransformerStack`) apply appropriate residual mults
- `examples/demo.ipynb`
- `uu.Linear` to `bias=False` & `uu.*Norm` to `elementwise_affine=False`
- `U.softmax` and `U.scaled_dot_product_attention` scaling rules to use the empirical fit
- `tau` in `residual_split`/`residual_add`
- `uu.MHSA` to use `U.scaled_dot_product_attention`
- `uu.MLP` to use SwiGLU
- `uu.TransformerLayer`, `uu.TransformerDecoder` to use `RMSNorm` and various other tweaks
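As a quick illustration of the surface described above (the `uu` / `U` import aliases follow the convention in the change list; the constructor arguments are assumptions for illustration, not the exact signatures):

```python
import torch
import unit_scaling as uu            # modules: uu.*
import unit_scaling.functional as U  # ops: U.*

hidden, vocab = 256, 1024            # illustrative sizes

linear = uu.Linear(hidden, hidden)          # bias=False is now the default
readout = uu.ReadoutLinear(hidden, vocab)   # applies 1/fan_in output scaling

x = torch.randn(8, hidden)
h = U.silu(linear(x))                       # unit-scaled SiLU op
logits = readout(h)
```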