
Updates to support u-muP, as the new default behaviour #58

Merged: DouglasOrr merged 27 commits into main from umup-updates on Jul 24, 2024

Conversation

@DouglasOrr (Collaborator) commented on Jul 16, 2024

Demo notebook.

README header.

Changes:

  • Add mults to nonlinear U.* ops and uu.* modules
  • Add U.readout_linear, uu.ReadoutLinear to apply 1/fan_in scaling (see the scaling sketch after this list)
  • Add U.rms_norm, uu.RMSNorm, U.silu, uu.SiLU, U.silu_glu, U.mse_loss
  • Add uu.Parameter, uu.optim.* to provide LR scaling
  • Add uu.Trunk, uu.TransformerStack to set mup_scaling_depth and (for TransformerStack) apply appropriate residual mults
  • Add demo notebook examples/demo.ipynb
  • Change the default constraint from "gmean" -> "to_output_scale"
  • Change default uu.Linear to bias=False & uu.*Norm to elementwise_affine=False
  • Change U.softmax and U.scaled_dot_product_attention scaling rules to use the empirical fit
  • Change definition of tau in residual_split/residual_add
  • Change uu.MHSA to use U.scaled_dot_product_attention
  • Change uu.MLP to use SwiGLU
  • Change uu.TransformerLayer, uu.TransformerDecoder to use RMSNorm and various other tweaks
  • Update top of README (other docs updates are lagging)
  • +probably more
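
For intuition on the readout scaling mentioned above, here is a toy illustration of the forward-scale arithmetic only (plain PyTorch, not the library's U.linear / U.readout_linear implementations, which also handle backward scaling and constraints):

```python
import torch

torch.manual_seed(0)
batch, fan_in, fan_out = 256, 1024, 1024

x = torch.randn(batch, fan_in)      # unit-scale input activations
w = torch.randn(fan_out, fan_in)    # unit-scale (untrained) weight

# Hidden linear: 1/sqrt(fan_in) scaling keeps the output at unit scale.
hidden = x @ w.t() / fan_in ** 0.5

# Readout linear: 1/fan_in scaling shrinks the output (logit) scale,
# as in muP-style readouts.
readout = x @ w.t() / fan_in

print(hidden.std())   # ~1.0
print(readout.std())  # ~1 / sqrt(fan_in) ~= 0.03
```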

@DouglasOrr DouglasOrr requested a review from lyprince July 23, 2024 06:55
@DouglasOrr DouglasOrr marked this pull request as ready for review July 23, 2024 06:57
@DouglasOrr (Collaborator, Author) commented:

@lyprince, thank you for agreeing to review!

@thecharlieblake FYI

This is still WIP, but (I think/hope) the major changes are in.

unit_scaling/_modules.py:
)
self.is_causal = is_causal
self.mult = mult
self.linear_qkv = Linear(hidden_size, 3 * hidden_size, constraint=constraint)
@lyprince commented on Jul 23, 2024:
Without a linear specialisation (like linear_readout), this will have an extra sqrt(3) factor in the scale.

Accordingly, our default implementation does not fuse the qkv matmul.

@DouglasOrr (Collaborator, Author) replied:
Ah, yes, I wondered about this; the cheeky thing I was thinking is that when using the new-default constraint "to_output_scale" it shouldn't matter, as the scale just depends on fan_in. But maybe this is a bit cheeky as it can be overridden.

@lyprince replied on Jul 24, 2024:
> the scale just depends on fan_in

Won't the output scale of this op be 3 * fan_in rather than fan_in though? I see in the demo notebook that attn_qkv.weight.grad.std = 0.62, which is pretty close to 1/sqrt(3) = 0.58.
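
As an aside, one place a sqrt(3) naturally appears is the backward pass through the fused projection: its output fan is 3 * hidden_size, so a scaling rule that looks only at fan_in is off by sqrt(3) somewhere in the backward graph. A toy check of that arithmetic (plain PyTorch, not the library's scaling rules; it does not by itself settle which gradient in the thread picks up the factor):

```python
import torch

torch.manual_seed(0)
batch, d = 4096, 1024

x = torch.randn(batch, d, requires_grad=True)
w_fused = torch.randn(3 * d, d)

# Fused qkv projection, with forward and backward both scaled by 1/sqrt(fan_in),
# as a fan_in-only rule (e.g. "to_output_scale") would do.
y = x @ w_fused.t() / d ** 0.5

# Backpropagate a unit-scale gradient.
(grad_x,) = torch.autograd.grad(y, x, torch.randn_like(y))

# The backward fan of the fused op is 3 * d, not d, so grad_x comes out at
# ~sqrt(3 * d) / sqrt(d) = sqrt(3) times unit scale.
print(grad_x.std())  # ~1.73
```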

unit_scaling/_modules.py (two further review threads, outdated and resolved)
@lyprince left a comment:
Looks good! I've done my best to check for correctness. I have some uncertainty around residual scaling. A test that unit scale is preserved across depth under the residual scaling scheme would give me more confidence (a sketch of such a test follows below).

Other than that, my main comments are around:

  1. Documentation that points to the relevant parts of the paper, justifying design choices / changes.
  2. Doubts over whether nn.Sequential is the right choice for keeping track of depth, given common patterns for implementing masks and positional embeddings.
  3. Clarification on whether it is intentional to fuse the qkv projection in MHSA.
  4. Testing compatibility of the Parameter class with torch.compile (does FakeTensor trigger a deepcopy failure?).
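
Regarding the residual-scaling test mentioned above, here is a minimal sketch of the kind of check meant, assuming an interpolation-style residual add sqrt(1 - tau) * residual + sqrt(tau) * branch (the library's actual residual_add / residual_split definitions, and choice of tau, may differ):

```python
import torch

torch.manual_seed(0)
depth, width, batch = 64, 512, 4096
tau = 1.0 / depth  # per-layer branch weight; one plausible choice, not necessarily the library's

x = torch.randn(batch, width)  # unit-scale residual stream
for _ in range(depth):
    branch = torch.randn(batch, width)  # stand-in for a unit-scaled branch output
    # Interpolation-style residual add: (1 - tau) + tau = 1, so unit variance is
    # preserved at every layer (assuming the branch is independent of the stream).
    x = (1.0 - tau) ** 0.5 * x + tau ** 0.5 * branch

assert abs(x.std().item() - 1.0) < 0.05, x.std()
print(x.std())  # ~1.0, independent of depth
```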

"# Config & helpers\n",
"torch.backends.cuda.matmul.allow_tf32 = True\n",
"torch.backends.cudnn.allow_tf32 = True\n",
"def show_layer_stats(layer: nn.Module, input_shape: Tuple[int, ...]) -> None:\n",

@lyprince commented:
Demo notebook looks good! The only thing that broke my flow was searching for show_layer_stats; it was easy to miss colocated with the imports. My preference would be to define it next to its first use, or at least give it its own cell.
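
For reference, a helper along these lines (a hypothetical sketch; the notebook's actual show_layer_stats may differ) would print the forward/backward scales of a layer:

```python
import torch
from torch import nn
from typing import Tuple

def show_layer_stats(layer: nn.Module, input_shape: Tuple[int, ...]) -> None:
    # Hypothetical sketch: run one forward/backward pass with unit-scale
    # inputs and gradients, then report the resulting stds.
    x = torch.randn(*input_shape, requires_grad=True)
    y = layer(x)
    y.backward(torch.randn_like(y))
    print(f"output std:     {y.std().item():.3f}")
    print(f"input grad std: {x.grad.std().item():.3f}")
    for name, p in layer.named_parameters():
        grad_std = p.grad.std().item() if p.grad is not None else float("nan")
        print(f"{name}: std={p.std().item():.3f}, grad std={grad_std:.3f}")

# e.g. show_layer_stats(nn.Linear(256, 256), (64, 256))
```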

@DouglasOrr DouglasOrr merged commit 087133f into main Jul 24, 2024
1 check passed
@DouglasOrr DouglasOrr deleted the umup-updates branch July 24, 2024 15:43