Merge pull request #95 from graphcore-research/pom-2024-09
pom-2024-09
Showing 19 changed files with 322 additions and 0 deletions.
34 changes: 34 additions & 0 deletions
_posts/papers-of-the-month/2024-09/2024-09-30-proper-conditioning.md
@@ -0,0 +1,34 @@
---
title: "September Papers: Proper Conditioning"
header:
  teaser: /assets/images/posts/2024-09/potm/twitter_card.png
  image: /assets/images/posts/2024-09/potm/twitter_card.png
  og_image: /assets/images/posts/2024-09/potm/twitter_card.png

date: 2024-09-30T01:00:00-00:00
potm_year: 2024
potm_month: 9

layout: paper-summaries-layout
category: "papers-of-the-month"
toc: true
toc_sticky: true
toc_label: "Papers"
toc_icon: "book"
author.twitter: "GCResearchTeam"
---

We're pleased to share four papers from different domains: LLM self-correction, FP8 training, generative crystals and optimisation. They are united, somewhat tenuously, by the importance of _proper conditioning_:

1. DeepMind researchers explain how _conditioning on the wrong distribution_ during supervised fine-tuning for self-correction is harmful but can be overcome using RL.
2. A novel Smooth-SwiGLU activation _"conditions" the numerics_ by inserting a scaling factor in just the right place, preventing late-training instability in FP8.
3. The GenMS architecture generates crystal structures for materials, _conditioning on high-level textual and low-level structural information_ for high-quality generation.
4. SOAP is an evolution of Shampoo, with conditioners in the name and _preconditioners forming the eigenbasis_ for optimisation.

You can be the judge of how tenuous the connection is, but I'd encourage you to check out the summaries either way.

_I hope you enjoy these as much as we did. Tell us we're wrong; tell us we're right [@GCResearchTeam](https://x.com/GCResearchTeam)._

---

{% include paper-summaries.md %}
49 changes: 49 additions & 0 deletions
_posts/papers-of-the-month/2024-09/papers/2024-09-27-GenMS.md
@@ -0,0 +1,49 @@
---
title: "Generative Hierarchical Materials Search"
paper_authors: "Sherry Yang, et al."
orgs: "Google DeepMind"
paper_link: "https://www.arxiv.org/abs/2409.06762"
tags:
  - LLMs
  - diffusion
  - GNNs
  - materials
potm_year: 2024
potm_month: 9
paper_order: 3
image_dir: "/assets/images/posts/2024-09/potm/GenMS/"
review_author:
  name: "Daniel Justus"
  link: "https://www.linkedin.com/in/daniel-justus/"
hidden: true
---

### The key idea

In recent years, machine-learning-based methods have increasingly been applied to assist the discovery of novel or improved materials with desired properties. In this paper, the authors present GenMS, an end-to-end generative model that produces crystal structures from language instructions. To that end, GenMS combines an LLM to process the user input, a diffusion model to generate molecular structures, and a GNN to predict the structures' properties and select the best candidates.

<img class="constrained_img_large" src="{{ page.image_dir | append: 'GenMS-pipeline.png' | relative_url }}" alt="End-to-end pipeline for the generation of crystal structures from language instructions.">

### Their method

The authors argue that data linking the properties of materials to their crystal structure exists at two different abstraction levels: high-level information is available as text, while lower-level structural information such as atom positions exists in crystal databases. To reflect this, the generative model is split into two components, with the chemical formulae of candidate materials serving as the intermediate representation:

1. An LLM trained on materials-science knowledge from sources such as textbooks is used to sample chemical formulae that satisfy the user's directions. Retrieval augmentation is used to gain additional information, and the formulae of crystals from existing databases are provided in the context to avoid generating known crystals.
2. A diffusion model trained on crystal-structure databases then generates crystal structures from these formulae. To improve the efficiency of the diffusion model, a simple representation using the 3D position and atomic number of each atom in the crystal is adopted instead of e.g. a graph.

<img class="constrained_img_large" src="{{ page.image_dir | append: 'GenMS-diffusion.png' | relative_url }}" alt="Diffusion model for crystal structures.">

As a final step, a pretrained GNN is used to predict the formation energy (and potentially other properties) of the generated crystal structures and rank them on this basis.

During inference, a tree search is performed to identify low-energy structures that satisfy the natural-language instructions. Here, the numbers of generated intermediate chemical formulae and crystal structures are hyperparameters that trade off compute cost against result quality.
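
The overall inference loop is straightforward to state. Below is a minimal sketch of how we picture it; the three callables and the branching factors are illustrative assumptions on our part, not the authors' code or API.

```python
def genms_search(instruction, sample_formulae, sample_structure, formation_energy,
                 n_formulae=8, n_structures_per_formula=4):
    """Illustrative sketch of GenMS inference: LLM -> formulae -> diffusion -> GNN ranking.

    `sample_formulae`, `sample_structure` and `formation_energy` are stand-ins for
    the retrieval-augmented LLM, the diffusion model and the pretrained GNN.
    """
    candidates = []
    # Step 1: the LLM proposes chemical formulae satisfying the instruction.
    for formula in sample_formulae(instruction, n_samples=n_formulae):
        # Step 2: the diffusion model generates candidate structures, each atom
        # represented by its 3D position and atomic number.
        for _ in range(n_structures_per_formula):
            structure = sample_structure(formula)
            # Step 3: the GNN predicts formation energy, used for ranking.
            candidates.append((formation_energy(structure), structure))
    # Larger branching factors trade extra compute for better result quality.
    return min(candidates, key=lambda c: c[0])[1]
```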

### Results

The main baseline presented in the study is an LLM that is prompted to generate crystal structures directly, i.e. without the chemical formulae as an intermediate representation, in the form of crystal information files. GenMS improves significantly on this baseline in all investigated quality criteria.
Furthermore, the authors demonstrate that the model follows simple prompts, such as requesting a metal or a material that is not present in a given list.

<img class="constrained_img_large" src="{{ page.image_dir | append: 'GenMS-results.png' | relative_url }}" alt="GenMS results compared to a prompted LLM.">

### Takeaways

The possibility of sampling materials based on natural-language instructions in an end-to-end fashion is a promising direction for improving materials generation and making it more accessible. However, the authors acknowledge a few shortcomings that require further work. In particular, handling more specific user input (e.g. "generate a semiconductor"), generating more complex crystal structures, and accounting for further criteria such as the synthesizability of the generated material remain challenging.
75 changes: 75 additions & 0 deletions
_posts/papers-of-the-month/2024-09/papers/2024-09-27-fp8_smooth_swiglu.md
@@ -0,0 +1,75 @@
---
title: "Scaling FP8 training to trillion-token LLMs"
paper_authors: "Maxim Fishman, Brian Chmiel, et al."
orgs: "Habana Labs, Technion"
paper_link: "https://arxiv.org/abs/2409.12517"
tags:
  - efficient-inference
  - quantisation
potm_year: 2024
potm_month: 9
paper_order: 2
image_dir: "/assets/images/posts/2024-09/potm/fp8_smooth_swiglu/"
review_author:
  name: "Paul Balanca"
  link: "https://www.linkedin.com/in/paulbalanca"
hidden: true
---

### The key idea

Building upon recent literature on low-precision FP8 training, the authors investigate the FP8 training stability of trillion-token LLMs (a ~20-fold increase over previously published work). Uncovering a new form of critical instability, they present an improved *Smooth-SwiGLU* activation function which prevents activation spikes (outliers) from causing training divergence in LLMs.

<img src="{{ page.image_dir | append: 'fp8-training-instable.png' | relative_url }}" class="constrained_img" alt="Training instability in FP8 due to SwiGLU.">
<figcaption>Training instability in FP8 due to the SwiGLU activation function.</figcaption>

### Background

Machine learning researchers, especially in AI hardware companies, have spent the last couple of years investigating which 8-bit floating-point formats are suitable for neural-network training and inference. The literature on the subject converges towards two formats: **E4M3** and **E5M2**. The former is used to represent weights and activations, while the latter is used for gradients, which require a higher dynamic range.

Because of the much smaller dynamic range compared to BF16 (which is commonly used in LLM training), FP8 LLM training requires ad-hoc per-tensor scaling using data statistics (usually the absolute max) in order to keep training stable.
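
To make that concrete, here is a minimal NumPy sketch of per-tensor absolute-max scaling (our illustration, not the paper's code); the constants are the largest finite values of the two FP8 formats, and the clip stands in for the hardware FP8 cast.

```python
import numpy as np

E4M3_MAX = 448.0    # largest finite value in FP8 E4M3 (weights, activations)
E5M2_MAX = 57344.0  # largest finite value in FP8 E5M2 (gradients)

def scale_for_fp8(x: np.ndarray, fp8_max: float):
    """Per-tensor absolute-max scaling: map the largest magnitude in the tensor
    onto the largest representable FP8 value, keeping the scale to undo later."""
    scale = fp8_max / np.max(np.abs(x))
    x_scaled = np.clip(x * scale, -fp8_max, fp8_max)  # hardware would cast to FP8 here
    return x_scaled, scale

x_fp8, x_scale = scale_for_fp8(np.random.randn(1024, 1024).astype(np.float32), E4M3_MAX)
```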

Most of the FP8 literature has focused on small- to mid-scale experiments (at most 100B training tokens). As shown in this work, late-stage training of large LLMs presents additional numerical-stability challenges, with large outliers appearing in the transformer feed-forward layer.
||
### Their method | ||
|
||
As presented in the figure above, instabilities appear in late FP8 training of large LLMs. In this work, the authors narrow down the issue | ||
to the quadratic form of the *SwiGLU* activation function when combined with weight alignment. Experimental training data shows that | ||
large outliers appear more often during late training due to the correlation between `w1` and `w2` SwiGLU weights (which are uncorrelated initially). | ||
|
||
<img src="{{ page.image_dir | append: 'fp8-swiglu-hist.png' | relative_url }}" class="constrained_img_large" alt="SwiGLU weights correlation and outliers."> | ||
<figcaption>SwiGLU weights correlation and outliers.</figcaption> | ||
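
For reference, here is a plain NumPy restatement of the SwiGLU feed-forward block (ours, not the paper's): because the two branches are multiplied elementwise, inputs that align with both `w1` and `w2` grow quadratically, which is how the correlated weights produce the outliers shown above.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, w1, w2, w3):
    """Standard SwiGLU FFN: silu(x @ w1) gates x @ w2, then w3 projects back.
    The elementwise product makes the response roughly quadratic in x along
    directions where w1 and w2 are correlated."""
    return (silu(x @ w1) * (x @ w2)) @ w3
```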

These outliers lead to underflow or overflow during FP8 quantization when combined with delayed scaling, as the latter technique relies on the previous batch's statistics for optimal hardware usage. To circumvent this issue, the authors introduce a new *Smooth-SwiGLU* activation function which incorporates a per-channel scaling correction prior to FP8 casting, i.e.:

<img src="{{ page.image_dir | append: 'fp8-smooth-swiglu.png' | relative_url }}" class="constrained_img_large" alt="Smooth-SwiGLU channel scaling.">

As presented by the authors, channel max-scaling is well suited to hardware accelerators, as each chunk of data can be processed in parallel, and the resulting rescaling can be fused into the FP8 quantization of the input activations $x$ and the weights $w_3$ (third MLP layer):

<img src="{{ page.image_dir | append: 'fp8-smooth-swiglu2.png' | relative_url }}" alt="Smooth-SwiGLU definition.">
<figcaption>Smooth-SwiGLU definition.</figcaption>

We note that the introduction of the *Smooth-SwiGLU* activation preserves the overall FFN definition (from a mathematical point of view): the additional channel scaling factors are compensated later in the network, in the third MLP layer. We at Graphcore Research have proposed a similar approach in our recent [Scalify](https://github.com/graphcore-research/jax-scalify/) work: incorporating additional scaling in neural networks to improve numerical stability while keeping the same model definition.
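
Building on the SwiGLU sketch above, here is how we read the Smooth-SwiGLU trick (an illustrative sketch, not the authors' implementation): each channel of the gated product is divided by its own scale before the FP8 cast, and the inverse scales are folded into `w3`, so the FFN output is mathematically unchanged.

```python
def smooth_swiglu_ffn(x, w1, w2, w3, fp8_max=448.0):
    """Illustrative Smooth-SwiGLU: per-channel scaling before the FP8 cast,
    compensated in the third projection so the overall function is preserved."""
    hidden = silu(x @ w1) * (x @ w2)
    # Per-channel scale from the channel's absolute max (shape: 1 x d_ff).
    channel_scale = np.max(np.abs(hidden), axis=0, keepdims=True) / fp8_max
    hidden_fp8 = np.clip(hidden / channel_scale, -fp8_max, fp8_max)  # stand-in for the FP8 cast
    # Fold the inverse scaling into w3: (hidden / s) @ (s * w3) == hidden @ w3.
    return hidden_fp8 @ (channel_scale.T * w3)
```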

### Results

Training experiments on a 7B Llama 2 model show the improved stability of FP8 LLM training when using the Smooth-SwiGLU activation: the training loss as well as zero-shot downstream task results match the BF16 baseline. The use of Smooth-SwiGLU leads to only a small drop in FP8 training acceleration, from 37% to 34%, due to the cost of channel rescaling.

<img src="{{ page.image_dir | append: 'fp8-smooth-swiglu-training.png' | relative_url }}" class="constrained_img" alt="FP8 LLM training with Smooth-SwiGLU.">
<figcaption>FP8 LLM training with Smooth-SwiGLU.</figcaption>

The authors also demonstrate that the FP8 E5M2 format can be used for storing the Adam optimizer's second moment (as presented in previous works, the first moment can be represented using E4M3).
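
As a small illustration of that last point (ours, assuming the `ml_dtypes` package for NumPy FP8 types), the second moment is a running average of squared gradients and so spans a much wider range, which makes E5M2's extra exponent bits the better fit:

```python
import numpy as np
import ml_dtypes  # provides float8_e4m3fn / float8_e5m2 NumPy dtypes

grads = np.random.randn(4096).astype(np.float32) * 0.1
m = grads.copy()   # Adam first moment (EMA of gradients): gradient-scale values
v = grads ** 2     # Adam second moment (EMA of squared gradients): much wider range

m_fp8 = m.astype(ml_dtypes.float8_e4m3fn)  # E4M3: more mantissa bits (precision)
v_fp8 = v.astype(ml_dtypes.float8_e5m2)    # E5M2: more exponent bits (dynamic range)
```
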
98 changes: 98 additions & 0 deletions
_posts/papers-of-the-month/2024-09/papers/2024-09-27-llm-correction-via-rl.md
@@ -0,0 +1,98 @@
---
title: "Training Language Models to Self-Correct via Reinforcement Learning"
paper_authors: "Aviral Kumar, Vincent Zhuang, et al."
orgs: "Google DeepMind"
paper_link: "https://arxiv.org/abs/2409.12917"
tags:
  - self-correction
  - reinforcement-learning
  - fine-tuning
  - LLMs
potm_year: 2024
potm_month: 9
paper_order: 1
image_dir: "/assets/images/posts/2024-09/potm/llm-correction-via-rl/"
review_author:
  name: "Charlie Blake"
  link: "https://x.com/thecharlieblake"
hidden: true
---

### The key idea

Users of LLMs will be aware that the models can sometimes recognise and correct their own mistakes. This prompts the question: if a model has the capability to identify some of its own failures, can we leverage this to improve the model?

This is easier said than done. This paper shows that supervised fine-tuning (SFT) --- the dominant post-training approach for LLMs --- has some inevitable failure modes when trying to teach a model to self-correct. What's needed, and what they demonstrate, is that an RL-based approach can prevail.

This is significant: _true RL_ has only just broken into the LLM training space, in the form of [OpenAI's o1 model](https://openai.com/index/learning-to-reason-with-llms/), but few details have been released. This work presents a significant step towards realising the benefits of RL in helping language models to reason better.

<img src="{{ page.image_dir | append: 'figure-6.png' | relative_url }}" alt="An overview of the method, named SCoRe. Supervised approaches can lead to distributional mismatch or never-correcting behaviours. SCoRe addresses this via a 2-stage RL process, where stage 1 encourages the model to produce effective corrections and stage 2 focuses on both initial response and correction.">
<figcaption>An overview of the method, named SCoRe. Supervised approaches can lead to distributional mismatch or never-correcting behaviours. SCoRe addresses this via a 2-stage RL process, where stage 1 encourages the model to produce effective corrections and stage 2 focuses on both initial response and correction.</figcaption>

### Background

The most straightforward approach to solving the self-correction problem is simply (see the sketch after this list):

1. Take a dataset of question-answer pairs for some reasoning task.
2. For each question, prompt the model to generate a solution.
3. Evaluate each solution and remove those which are _correct_.
4. Prompt the model to generate a correction to each remaining incorrect solution.
5. Evaluate the final solutions, and now filter out the _incorrect_ ones.
6. Take this dataset of two-stage "corrected" answers and train the model on it.
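
A rough sketch of this pipeline in code (our paraphrase; `generate`, `correct` and `is_correct` stand in for the model's two sampling passes and the answer checker):

```python
def build_correction_dataset(qa_pairs, generate, correct, is_correct):
    """STaR-style data collection: keep only traces where a wrong first attempt
    is turned into a right second attempt."""
    dataset = []
    for question, answer in qa_pairs:
        first = generate(question)
        if is_correct(first, answer):
            continue  # step 3: drop questions the model already solves first time
        second = correct(question, first)   # step 4: prompt for a correction
        if is_correct(second, answer):      # step 5: keep only successful corrections
            dataset.append((question, first, second))
    return dataset  # step 6: fine-tune on these two-stage traces
```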

This is the basis of the [STaR method](https://arxiv.org/abs/2203.14465), which the authors use as a baseline, alongside Pair-SFT, which works similarly but uses arbitrary pairs of incorrect-correct responses to a given prompt as training data.

The authors test these methods and see the following:

<img src="{{ page.image_dir | append: 'table-1.png' | relative_url }}" alt="An evaluation of the STaR and PairSFT baselines shows that neither is able to offer significant improvements.">

STaR slightly improves the initial attempt, but is poor at correcting --- so much so that it tends to make answers worse, not better! Pair-SFT offers a modest accuracy improvement, though this is largely down to a drop in the value of the final column, which indicates the fraction of correct responses the model ruins via wrong "corrections". So in summary: the only improvement we really see is the model learning to be much more cautious in correcting itself.

They trace these difficulties back to two problems:

1. The model tends towards a **minimal edit** policy, where it tries to change as little as possible to avoid degrading the original response.
2. The model is trained on data from its original distribution over responses, yet training causes this distribution to change, leading to **distribution mismatch**.

### Their method

The two-stage RL-based method they design targets each of these problems in turn.

**Stage 1:** The first stage uses RL to maximise the following objective:

<div>
$$
\mathbb{E}\left[ \hat{r}(\mathbf{y}_2, \mathbf{y}^*) - \alpha D_{KL} \left( \pi_{\theta}(\cdot \mid \mathbf{x}_1) \,\|\, \pi_{\text{ref}}(\cdot \mid \mathbf{x}_1) \right) \right].
$$
</div>

Here $\hat{r}(\mathbf{y}_2, \mathbf{y}^*)$ is some "correctness" function that acts as a reward, which crucially is based on $\mathbf{y}_2$, the model's _second_ attempt at the problem. The KL term acts on the _first_ attempt, encouraging the model to keep its first guess the same as the original ("reference") model.

We can see from this that the aim is to encourage the model to learn strong correction behaviour, by fixing the first attempt and optimizing just the second (approximately). This addresses the minimal edit problem.

**Stage 2:** Having encouraged strong correction in stage 1, the full problem is addressed in stage 2, which maximises:

<div>
$$
\mathbb{E}\left[ \sum_{i=1}^{2} \hat{r}(\mathbf{y}_i, \mathbf{y}^*) - \beta D_{KL} \left( \pi_{\theta}(\cdot \mid \mathbf{x}_i) \,\|\, \pi_{\text{ref}}(\cdot \mid \mathbf{x}_i) \right) \right].
$$
</div>

Here the RL objective is over both attempts, with a weaker KL penalty over both acting as a mild regulariser. A reward-shaping step is also used here to up-weight examples where incorrect first attempts are successfully corrected.

The key difference between this and SFT is that the data used to update the model is always generated by the current model. This avoids the distribution mismatch problem.
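
One way to realise that reward shaping, as we read it (an illustrative sketch rather than the authors' exact formulation): the second attempt earns a bonus proportional to its improvement over the first, so incorrect-to-correct flips are up-weighted while correct-to-incorrect flips are penalised.

```python
def shaped_rewards(y1, y2, y_star, reward_fn, alpha=0.5):
    """Per-attempt rewards for stage 2, with a shaping bonus on the second attempt.
    `reward_fn` is assumed to return 1.0 for a correct answer and 0.0 otherwise."""
    r1 = reward_fn(y1, y_star)
    r2 = reward_fn(y2, y_star)
    return r1, r2 + alpha * (r2 - r1)  # flips from wrong to right get up-weighted
```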

### Results

In short, it works. Results are good on maths problems, and even better on coding tasks:

<img src="{{ page.image_dir | append: 'table-4.png' | relative_url }}" alt="An evaluation of the base model, Self-refine and PairSFT versus SCoRe on the HumanEval coding benchmark. SCoRe is much better at correcting itself than other methods.">

The first-attempt accuracy is slightly degraded, but the second attempt is substantially better than that of any other method. The main reason for this is shown in the second-to-last column: a large increase in incorrect answers becoming correct, which is the key objective.

The paper shows several other evaluations and ablations, making a strong case for the method.

### Takeaways

This paper makes a compelling case that supervised fine-tuning is limited as a post-training procedure, and that for some problems (such as self-correction) some kind of on-policy RL is required. Carefully designed objectives are needed to make this work, but it appears to significantly boost a model's ability to reason at inference time.

This is just the start. The authors consider a fairly simple problem setting: a single correction attempt on a zero-shot answer, with no supervision as to the source of error. One could imagine a similar approach with many correction attempts, possibly on chain-of-thought responses, and with more granular feedback. This promises to be a significant direction of future LLM research, with significant computational and algorithmic implications.