
Switch layer kernel implementation in the config #35

Draft
wants to merge 17 commits into develop

Conversation


@cathalobrien cathalobrien commented Sep 11, 2024

Describe your changes

This PR makes it possible to switch the implementation of Linear and LayerNorm kernels in the config.

At the moment we use the torch.nn implementation for many layers in the Anemoi model, e.g. torch.nn.LayerNorm and torch.nn.Linear. This has the advantage of being available out of the box with torch and portable to many systems (CPU, AMD and Nvidia GPUs). However, other layer implementations might be more efficient for certain hardware, or use different algorithms we want to explore (RMSNorm is an alternate implementation of LayerNorm, for instance).
These might only run on certain systems (e.g. Nvidia's transformer_engine.pytorch provides layer implementations optimized for their GPUs).
Therefore, we'd like to be able to flexibly take advantage of these faster kernels when they're available, without losing the ability to fall back to torch.nn for resiliency.

This PR adds the following block to config/model/.yaml:

  layer_kernels:
    LayerNorm:
      #_target_: "transformer_engine.pytorch.LayerNorm"
      _target_: "liger_kernel.transformers.rms_norm.LigerRMSNorm"
      #_target_: "torch.nn.LayerNorm" #the default PyTorch implementation
      _partial_: True
      #Any arguments to your chosen function go here e.g.
      #bias: False
    Linear:
      #_target_: "transformer_engine.pytorch.Linear"
      _target_: "torch.nn.Linear"
      _partial_: True

In the future, this syntax could be extended to replace other layers (e.g. mlp) if required.

The calls to torch.nn are then replaced with

- self.layer_norm1 = nn.LayerNorm(num_channels)
+ LayerNorm=layer_kernels['LayerNorm']
+ self.layer_norm1 = LayerNorm(num_channels)
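
For anyone unfamiliar with Hydra's `_partial_` mechanism, here is a minimal, self-contained sketch (using a stand-in config dict rather than the actual Anemoi config) of how an entry like the one above resolves into something you can call like a layer class:

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Stand-in for the layer_kernels block shown above (illustrative only, not the real config)
    layer_kernels_cfg = OmegaConf.create({
        "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True},
        "Linear": {"_target_": "torch.nn.Linear", "_partial_": True, "bias": False},
    })

    # With _partial_: True, instantiate() returns a functools.partial around the target class,
    # with any extra config keys (e.g. bias: False) pre-bound as keyword arguments.
    layer_kernels = {name: instantiate(entry) for name, entry in layer_kernels_cfg.items()}

    LayerNorm = layer_kernels["LayerNorm"]  # callable like a class
    Linear = layer_kernels["Linear"]

    layer_norm1 = LayerNorm(64)   # -> torch.nn.LayerNorm(64)
    proj = Linear(64, 128)        # -> torch.nn.Linear(64, 128, bias=False)

    x = torch.randn(2, 64)
    print(layer_norm1(x).shape, proj(x).shape)  # torch.Size([2, 64]) torch.Size([2, 128])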

You can pass any parameters to your chosen kernel in the config file, after "_partial_: True". Hydra tries to load the desired kernel in "models/encoder_processor_decoder.py". If the desired library isn't available, it currently falls back to torch.nn.

        for kernel in self.layer_kernels:
            kernel_entry = self.layer_kernels[kernel]
            try:
                instantiate(kernel_entry)
            except InstantiationException:
                LOGGER.info(f"{kernel_entry['_target_']} not available! Falling back to torch.nn.{kernel}")
                # Replace the entry, to remove any args passed to the original kernel
                self.layer_kernels[kernel] = DotDict({"_target_": f"torch.nn.{kernel}", "_partial_": True})
        LOGGER.debug(f"{self.layer_kernels=}")

I am in two minds about this. On the one hand, it helps ensure you can still run if you are missing a library (which might be necessary when doing inference in a different environment to the one where the model was trained). But on the other hand, maybe it would be better to fail loudly when the user requests an uninstalled library. Otherwise, a user could be under the impression they are using an optimised kernel when they are not. It also feels like poor programming practice to branch on exceptions like I am doing here. Open to suggestions on this (one possible alternative is sketched below).
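
For reference, a minimal sketch of the "fail loudly" alternative (not what this PR currently does), assuming the same imports and self.layer_kernels structure as the snippet above:

    for kernel, kernel_entry in self.layer_kernels.items():
        try:
            instantiate(kernel_entry)
        except InstantiationException as err:
            # No silent fallback: surface the misconfiguration immediately.
            raise RuntimeError(
                f"Layer kernel '{kernel_entry['_target_']}' requested for '{kernel}' could not be "
                f"instantiated. Install the required library or set the config back to torch.nn.{kernel}."
            ) from err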

This feature makes it easy to try out new kernels in an end-to-end machine learning run, rather than simply doing a standalone kernel benchmark. Using this feature I was able to trial the RMSNorm implementation of LayerNorm from Liger Kernel in two lines (pip install liger_kernel; vim config/model/transformer.yaml), and I saw a ~10% speedup.

LayerNorm                                       | run_training_batch time (s) | training_avg_throughput (iter/s)
torch.nn.LayerNorm                              | 0.93532                     | 0.88574
liger_kernel.transformers.rms_norm.LigerRMSNorm | 0.82982                     | 0.99186

Type of change


New feature (non-breaking change which adds functionality)

This change requires a documentation update

Checklist before requesting a review

  • I have performed a self-review of my code
  • My code follows the style guidelines of this project
  • I have commented my code, particularly in hard-to-understand areas
  • I have updated the documentation and docstrings to reflect the changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have ensured that the code is still pip-installable after the changes and runs
  • I have not introduced new dependencies in the inference portion of the model
  • I have run this on a single GPU
  • I have run this on multi-GPU or multi-node
  • I have run this on LUMI (or made sure the changes work independently)
  • I have run the Benchmark Profiler against the old version of the code

@FussyDuck

FussyDuck commented Sep 11, 2024

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

✅ cathalobrien
❌ Cathal Liam O Brien


Cathal Liam O Brien seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.

@clessig

clessig commented Sep 11, 2024

I think silently switching to a different implementation is dangerous. As we see with the attention, there might be differences in the implementation that the user should be aware of and think about when selecting a different backend.

@cathalobrien
Author

LOGGER.info(f"{kernel_entry['_target_']} not available! Falling back to torch.nn.{kernel}")

I agree, good analogy with attention. At the moment there's a warning, but stdout is easily missed.

However, we have to make sure inference is still possible without any additional libraries that might have been used during training. This could be done by resetting the 'layer_kernels' config entry to "torch.nn" during inference, but at the moment there's apparently no easy way to tell if you're in inference or training, so this requires some thought.

src/anemoi/models/layers/attention.py (outdated review thread, resolved)
src/anemoi/models/layers/block.py (outdated review thread, resolved)
      mlp1.append(act_func())
-     mlp1.append(nn.Linear(hidden_dim, out_features))
+     mlp1.append(Linear(hidden_dim, out_features))

      if final_activation:
          mlp1.append(act_func())

      if layer_norm:
          mlp1.append(AutocastLayerNorm(out_features))
Member

do we still want the AutocastLayerNorm here? should this be replaced by a LayerNorm?

Author

Is that where we landed from the Nvidia meeting way back? I can do some tests

src/anemoi/models/models/encoder_processor_decoder.py (outdated review threads, resolved)
@@ -69,6 +70,31 @@ def __init__(

self.num_channels = config.model.num_channels

# If self.layer_kernels entry is missing from the config, use torch.nn by default
Member

Where is the config? I'm not seeing it in this PR.

Author

The configs are part of anemoi-training now, so I guess I'll just update the docs for anemoi-training to say this feature exists and show some examples. Otherwise I can make a small PR to anemoi-training with an updated config.

src/anemoi/models/models/encoder_processor_decoder.py (outdated review thread, resolved)
@cathalobrien
Author


I'll make a PR to ai_models, where I reset layer_kernels to torch.nn if they've been set. Seems like the most straightforward way to handle inference.
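
A minimal sketch of what such a reset could look like (the helper name and config layout are illustrative assumptions, not part of this PR):

    from omegaconf import DictConfig, OmegaConf

    def reset_layer_kernels_to_torch(config: DictConfig) -> DictConfig:
        """Hypothetical helper for inference: force every layer_kernels entry back to torch.nn."""
        if "layer_kernels" in config.get("model", {}):
            for kernel in config.model.layer_kernels:
                config.model.layer_kernels[kernel] = OmegaConf.create(
                    {"_target_": f"torch.nn.{kernel}", "_partial_": True}
                )
        return config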

@JesperDramsch
Member

Open to a discussion, but I think we shouldn't have "default fallbacks" for these types of things, if the code fails to instantiate from the provided config.

When someone submits their job with sbatch for a specific experiment, it should probably be obvious that the config wasn't valid, but with these fallbacks it just "runs through" anyway and you have to monitor your logs to see whether anything went wrong with your experiment due to a misconfiguration or a wrong environment.

@cathalobrien
Author

cathalobrien commented Oct 9, 2024


The current behaviour is (a sketch follows below):

  • if the layer_kernels entry is missing from the config -> use torch.nn by default. This maintains backwards compatibility with older config files.

  • if the config entry exists but the library it describes can't be loaded -> throw an error.
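
A minimal sketch of that behaviour, assuming the same imports (instantiate, InstantiationException, DotDict) as the earlier snippet; names are illustrative, not the exact PR code:

    # Missing entry -> default to torch.nn for backwards compatibility with older configs.
    if "layer_kernels" not in config.model:
        layer_kernels = DotDict(
            {name: {"_target_": f"torch.nn.{name}", "_partial_": True} for name in ("LayerNorm", "Linear")}
        )
    else:
        layer_kernels = config.model.layer_kernels
        # Entry present but the library cannot be loaded -> raise instead of silently falling back.
        for kernel_entry in layer_kernels.values():
            instantiate(kernel_entry)  # raises InstantiationException if the target is unavailable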
