diff --git a/.pin/constraints-rocm-torch.txt b/.pin/constraints-rocm-torch.txt
index 4fe6ae9d..0cf3fcbc 100644
--- a/.pin/constraints-rocm-torch.txt
+++ b/.pin/constraints-rocm-torch.txt
@@ -368,7 +368,7 @@ pytinyrenderer==0.0.14
     # via brax
 pytorch-lightning==2.4.0
     # via lightning
-pytorch-triton-rocm==3.0.0
+pytorch-triton-rocm==3.1.0
     # via torch
 pytz==2024.1
     # via pandas
@@ -437,7 +437,7 @@ six==1.16.0
     #   tensorboard
 submitit==1.5.1
     # via -r benchmarks/dinov2/requirements.in
-sympy==1.13.2
+sympy==1.13.1
     # via torch
 tabulate==0.9.0
     # via fvcore
@@ -461,7 +461,7 @@ tokenizers==0.19.1
     # via transformers
 toolz==0.12.1
     # via chex
-torch==2.4.0+rocm6.0
+torch==2.5.1+rocm6.1
     # via
     #   -r benchmarks/brax/requirements.in
     #   -r benchmarks/dinov2/requirements.in
@@ -483,7 +483,7 @@ torch==2.4.0+rocm6.0
     #   torchmetrics
     #   torchvision
     #   xformers
-torchao==0.3.1
+torchao==0.6.1
     # via torchtune
 torchcompat==1.1.4
     # via
@@ -498,9 +498,10 @@ torchmetrics==1.4.1
     #   -r benchmarks/dinov2/requirements.in
     #   lightning
     #   pytorch-lightning
-torchtune==0.2.1
+#torchtune==0.3.1
+git+https://github.com/pytorch/torchtune.git@e1caa9f82fea24d728f9b244a9dd1957f5ed7465
     # via -r benchmarks/llm/requirements.in
-torchvision==0.19.0+rocm6.0
+torchvision==0.20.0+rocm6.1
     # via
     #   -r benchmarks/diffusion/requirements.in
     #   -r benchmarks/dinov2/requirements.in
@@ -585,7 +586,7 @@ werkzeug==3.0.3
     # via
     #   flask
     #   tensorboard
-xformers==0.0.27.post2
+xformers==0.0.28.post3
     # via -r benchmarks/dinov2/requirements.in
 xxhash==3.5.0
     # via datasets
diff --git a/benchmarks/llm/configs/llama3_70B_full.yaml b/benchmarks/llm/configs/llama3_70B_full.yaml
index 28463631..6e0ac45a 100644
--- a/benchmarks/llm/configs/llama3_70B_full.yaml
+++ b/benchmarks/llm/configs/llama3_70B_full.yaml
@@ -73,7 +73,7 @@ checkpointer:
 resume_from_checkpoint: False
 
 # Fine-tuning arguments
-batch_size: 2
+batch_size: 1 #2
 epochs: 3
 
 optimizer:
diff --git a/benchmarks/llm/configs/llama3_70B_lora.yaml b/benchmarks/llm/configs/llama3_70B_lora.yaml
index 37a10499..6259cc6d 100644
--- a/benchmarks/llm/configs/llama3_70B_lora.yaml
+++ b/benchmarks/llm/configs/llama3_70B_lora.yaml
@@ -14,8 +14,8 @@ model:
   lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
   apply_lora_to_mlp: False
   apply_lora_to_output: False
-  lora_rank: 16
-  lora_alpha: 32
+  lora_rank: 8 #16
+  lora_alpha: 16 #32
 
 tokenizer:
   _component_: torchtune.models.llama3.llama3_tokenizer
@@ -68,7 +68,7 @@ dataset:
   _component_: torchtune.datasets.alpaca_dataset
 seed: null
 shuffle: True
-batch_size: 2
+batch_size: 1 #2
 
 # Optimizer and Scheduler
 optimizer:
diff --git a/benchmarks/llm/recipes/full_finetune_single_device.py b/benchmarks/llm/recipes/full_finetune_single_device.py
index 98322579..c23b0747 100755
--- a/benchmarks/llm/recipes/full_finetune_single_device.py
+++ b/benchmarks/llm/recipes/full_finetune_single_device.py
@@ -1,4 +1,12 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
+# As of November 2024, torchtune development is moving very rapidly.
+# This recipe is based on the torchtune recipe at git commit e137afe (post release 0.3.1):
+# https://github.com/pytorch/torchtune/blob/7bfb3336446f0d874ab5d4595249839b735b7076/recipes/full_finetune_single_device.py
+
+# Torchtune 0.2.1 recipe with device instrumentation (c) Mila
+# https://github.com/mila-iqia/milabench/blob/a60a3aae21e87e46bcce403620a3f56c12878554/benchmarks/llm/recipes/full_finetune_single_device.py
+
+# The instrumentation edits (c) AMD
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
diff --git a/benchmarks/llm/recipes/lora_finetune_distributed.py b/benchmarks/llm/recipes/lora_finetune_distributed.py
index 00ec009f..9ea085d4 100755
--- a/benchmarks/llm/recipes/lora_finetune_distributed.py
+++ b/benchmarks/llm/recipes/lora_finetune_distributed.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -270,7 +272,6 @@ def setup(self, cfg: DictConfig) -> None:
 
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
         self._compile = cfg.get("compile", False)
-
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=self._enable_activation_checkpointing,
@@ -453,15 +454,33 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
 
+        if self._is_rank_zero:
+            log.info(
+                "model instantiated based on config on Rank 0 ..."
+            )
+
+
         set_trainable_params(model, get_adapter_params(model))
 
+        if self._is_rank_zero:
+            log.info(
+                "trainable parameters set..."
+            )
+
+
         if self._compile:
+            if self._is_rank_zero: log.info("compiling...")
+
             training.compile_model(model, verbose=self._is_rank_zero)
+            if self._is_rank_zero: log.info("done compiling")
+
 
         if enable_activation_checkpointing:
+            if self._is_rank_zero: log.info("setting activation checkpointing")
             training.set_activation_checkpointing(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
+            if self._is_rank_zero: log.info("done setting activation checkpointing")
 
         # For FSDP sharding
         fsdp_shard_conditions = [
@@ -470,6 +489,9 @@ def _setup_model(
                 names_to_match=custom_sharded_layers,
             )
         ]
+
+        if self._is_rank_zero: log.info("sharding model")
+
         training.shard_model(
             model=model,
             shard_conditions=fsdp_shard_conditions,
@@ -477,7 +499,11 @@ def _setup_model(
             reshard_after_forward=reshard_after_forward,
         )
 
+        if self._is_rank_zero: log.info("done sharding model")
+
         if lora_weights_state_dict:
+            if self._is_rank_zero: log.info("loading state dict")
+
             lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
                 model,
                 lora_weights_state_dict,
@@ -485,25 +511,32 @@ def _setup_model(
                 self._is_rank_zero,
                 cpu_offload=fsdp_cpu_offload,
             )
+            if self._is_rank_zero: log.info("loaded state dict")
+
         else:
             lora_missing, lora_unexpected = None, None
 
         # Initialize LoRA params and RoPE buffers
         with training.set_default_dtype(self._dtype), self._device:
             lora_device = "cpu" if fsdp_cpu_offload else self._device
-            for m in model.modules():
+            for i, m in enumerate(model.modules()):
                 if (
                     isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
                 ) and not lora_weights_state_dict:
+
                     # lora may not be covered in state dict
                     # if finetune for the 1st time
                     m.lora_a.to_empty(device=lora_device)
                     m.lora_b.to_empty(device=lora_device)
                     m.initialize_parameters()
+
                 # RoPE is not covered in state dict
                 if hasattr(m, "rope_init"):
                     m.rope_init()
 
+
+        if self._is_rank_zero: log.info("load from full model state dict")
+
         base_missing, base_unexpected = training.load_from_full_model_state_dict(
             model,
             base_model_state_dict,
@@ -515,9 +548,15 @@ def _setup_model(
         for m in model.modules():
             if hasattr(m, "initialize_dora_magnitude"):
                 is_dora = True
+                if self._is_rank_zero: log.info("init dora")
                 m.initialize_dora_magnitude()
+                if self._is_rank_zero: log.info("done init dora")
+
         if is_dora:
+            if self._is_rank_zero: log.info("load dora")
             load_dora_magnitudes(model)
+            if self._is_rank_zero: log.info("done load dora")
+        if self._is_rank_zero: log.info("validate missing")
         validate_missing_and_unexpected_for_lora(
             lora_attn_modules=self._lora_attn_modules,
             apply_lora_to_mlp=self._apply_lora_to_mlp,
@@ -527,14 +566,22 @@ def _setup_model(
             lora_missing=lora_missing,
             lora_unexpected=lora_unexpected,
         )
+        if self._is_rank_zero: log.info("done validate missing")
+
+        if self._is_rank_zero: log.info("validate no params")
+
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        if self._is_rank_zero: log.info("done validate no params")
+
+
+        if self._is_rank_zero: log.info("activation offload")
         # activation offloading
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
             model, enable_activation_offloading
         )
-
+        if self._is_rank_zero: log.info("done activation offload")
         # log
         if self._is_rank_zero:
             log.info(
diff --git a/benchmarks/llm/recipes/lora_finetune_single_device.py b/benchmarks/llm/recipes/lora_finetune_single_device.py
index 3d5b6895..8bd4bd1b 100755
--- a/benchmarks/llm/recipes/lora_finetune_single_device.py
+++ b/benchmarks/llm/recipes/lora_finetune_single_device.py
@@ -1,5 +1,12 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
+# As of November 2024, torchtune development is moving very rapidly.
+# This recipe is based on the torchtune recipe at git commit e137afe (post release 0.3.1):
+# https://github.com/pytorch/torchtune/blob/7bfb3336446f0d874ab5d4595249839b735b7076/recipes/lora_finetune_single_device.py
 
+# Torchtune 0.2.1 recipe with device instrumentation (c) Mila
+# https://github.com/mila-iqia/milabench/blob/a60a3aae21e87e46bcce403620a3f56c12878554/benchmarks/llm/recipes/lora_finetune_single_device.py
+
+# The instrumentation edits (c) AMD
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
diff --git a/benchmarks/llm/recipes/qat_distributed.py b/benchmarks/llm/recipes/qat_distributed.py
index 21143383..914fd93b 100755
--- a/benchmarks/llm/recipes/qat_distributed.py
+++ b/benchmarks/llm/recipes/qat_distributed.py
@@ -1,4 +1,12 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
+# As of November 2024, torchtune development is moving very rapidly.
+# This recipe is based on the torchtune recipe at git commit e137afe (post release 0.3.1):
+# https://github.com/pytorch/torchtune/blob/7bfb3336446f0d874ab5d4595249839b735b7076/recipes/qat_distributed.py
+
+# Torchtune 0.2.1 recipe with device instrumentation (c) Mila
+# https://github.com/mila-iqia/milabench/blob/a60a3aae21e87e46bcce403620a3f56c12878554/benchmarks/llm/recipes/qat_distributed.py
+
+# The instrumentation edits (c) AMD
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
@@ -6,36 +14,31 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import sys
 import time
 
 from functools import partial
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 from warnings import warn
 
 import torch
 from omegaconf import DictConfig, ListConfig
 
 from torch import nn
-from torch.distributed import init_process_group
-from torch.distributed.fsdp import (
-    CPUOffload,
-    FullOptimStateDictConfig,
-    FullStateDictConfig,
-    FullyShardedDataParallel as FSDP,
-    StateDictType,
-)
+from torch.distributed import destroy_process_group, init_process_group
+
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-
-from torchtune import config, modules, utils
+from torchtune import config, modules, training, utils
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.utils.activations import apply_selective_activation_checkpointing
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.activations import apply_selective_activation_checkpointing
 
 from tqdm import tqdm
 
-
 log = utils.get_logger("DEBUG")
 
 
@@ -56,10 +59,13 @@ class QATRecipeDistributed(FTRecipeInterface):
             weight and activation values to stabilize before fake quantizing them, potentially leading
             to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``.
 
-        - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU
-            is not supported.
+        - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
+            is supported via the ``fsdp_cpu_offload`` flag. Resharding of parameters after the forward pass is
+            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
+            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
+            DDP is currently not supported. Training on CPU is not supported.
 
-        - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
             flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
             activations in memory and instead recompute them during the backward pass. This is especially
             helpful for larger batch sizes when you're memory constrained. But these savings in memory
@@ -105,12 +111,12 @@ class QATRecipeDistributed(FTRecipeInterface):
 
     Raises:
         ValueError: If ``dtype`` is set to fp16.
+        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
     """
 
     def __init__(self, cfg: DictConfig) -> None:
-
         self._device = utils.get_device(device=cfg.device)
-        self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
+        self._dtype = training.get_dtype(cfg.dtype, device=self._device)
 
         if self._dtype == torch.float16:
             raise ValueError(
@@ -119,7 +125,7 @@ def __init__(self, cfg: DictConfig) -> None:
 
         if (
             cfg.get("fsdp_cpu_offload", False)
-            and cfg.get("fused", False)
+            and cfg.optimizer.get("fused", False)
             and not utils.torch_version_ge("2.4.0")
         ):
             raise RuntimeError(
@@ -131,20 +137,29 @@ def __init__(self, cfg: DictConfig) -> None:
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
 
+        if self._log_peak_memory_stats and self._device.type != "cuda":
+            log.info(
+                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+            )
+            self._log_peak_memory_stats = False
+
         # _is_rank_zero is used primarily for logging. In the future, the logger
         # should directly take care of this
-        _, rank = utils.get_world_size_and_rank()
+        _, rank = training.get_world_size_and_rank()
         self._is_rank_zero = rank == 0
 
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+        self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
+            cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
+        ]
         self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None)
         self._quantizer_mode = None
 
         # These are public properties which are updated by the checkpoint loader
         # when ``resume_from_checkpoint`` is `True` or validated in tests
-        self.seed = utils.set_seed(seed=cfg.seed)
+        self.seed = training.set_seed(seed=cfg.seed)
         self.epochs_run = 0
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
@@ -170,28 +185,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
         Updates the recipe state from checkpoint.
         """
         try:
-            self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
+            self.epochs_run = ckpt_dict[training.EPOCHS_KEY]
 
             # on mismatch, warn the user and prevent the override
-            if self.seed != ckpt_dict[utils.SEED_KEY]:
+            if self.seed != ckpt_dict[training.SEED_KEY]:
                 warn(
                     message=(
                         "Config value for seed does not match the checkpoint value, "
-                        f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}"
+                        f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
                     )
                 )
-                self.seed = ckpt_dict[utils.SEED_KEY]
-            if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]:
+                self.seed = ckpt_dict[training.SEED_KEY]
+            if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
                 warn(
                     message=(
                         "Config value for max_steps_per_epoch does not match the checkpoint value, "
-                        f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}"
+                        f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
                     )
                 )
-                self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]
+                self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]
 
             # on mismatch, warn the user but allow the override
-            if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]:
+            if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
                 warn(
                     message=(
                         "Config value for total_epochs does not match the checkpoint value, "
@@ -207,8 +222,8 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
 
     def setup(self, cfg: DictConfig) -> None:
         """
-        Sets up the recipe state correctly. This includes setting recipe attributes based
-        on the ``resume_from_checkpoint`` flag.
+        Set up the recipe. This includes training state (if resume_from_checkpoint is True),
+        model, tokenizer, loss, optimizer, sampler, and dataloader.
         """
         if self._is_rank_zero:
             self._metric_logger = config.instantiate(cfg.metric_logger)
@@ -216,34 +231,50 @@ def setup(self, cfg: DictConfig) -> None:
             # log config with parameter override
             self._metric_logger.log_config(cfg)
 
-        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
+        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
-        # ``_setup_model`` handles initialization and loading the state dict. This method
-        # should be called before ``_setup_optimizer`` since transforming the optimizer
-        # state dict requires the model
+        self._model_compile = cfg.get("compile", False)
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
-            memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
+            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
             fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
-            model_state_dict=ckpt_dict[utils.MODEL_KEY],
+            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
+            model_state_dict=checkpoint_dict[training.MODEL_KEY],
             ac_mode=cfg.get("ac_mode", None),
             ac_option=cfg.get("ac_option", None),
             quantizer_cfg=cfg.get("quantizer", None),
         )
-
         self._tokenizer = config.instantiate(cfg.tokenizer)
 
-        # _setup_optimizer should take in ckpt_dict only if training is resumed from
-        # checkpoint. Transforming the opt state dict is handled by this method
         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
-            opt_state_dict=ckpt_dict[utils.OPT_KEY]
-            if self._resume_from_checkpoint
-            else None,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )
 
+        # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
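+        # allow overriding the torch.compile backend via the TORCH_COMPILE_BACKEND env var (defaults to "inductor")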
+        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+        if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
+            # set num_output_chunks for model
+            self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+            if self._model_compile:
+                log.info("Compiling loss with torch.compile...")
+                # For CEWithChunkedOutputLoss, if we compile the entire class
+                # we lose the benefits from the chunked loss.
+                # Therefore, we only compile the cross entropy function + upcasting
+                self._loss_fn.compute_cross_entropy = torch.compile(
+                    self._loss_fn.compute_cross_entropy, backend=backend
+                )
+        else:
+            if self._model_compile:
+                log.info("Compiling loss with torch.compile...")
+                self._loss_fn = torch.compile(self._loss_fn, backend=backend)
+        log.info("Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
@@ -270,50 +301,110 @@ def setup(self, cfg: DictConfig) -> None:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch
 
+        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
+        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
+        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
+
+        # Used to ignore labels for loss computation
+        self.ignore_labels_cache = torch.full(
+            (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
+        )
+
+    def _setup_profiler(
+        self, cfg_profiler: Optional[DictConfig] = None
+    ) -> Union[torch.profiler.profile, DummyProfiler]:
+        """
+        Parses the `profiler` section of top-level `cfg` and sets up profiler
+
+        Args:
+            cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
+                `recipe.main`). Default None.
+
+        Returns:
+            profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
+            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if the profiler is not
+            enabled, so that the instrumented training loop does not need to be changed when profiling is disabled.
+
+        The profiler config can be provided in configs under the `profiler` key with the following layout:
+
+        .. code-block:: yaml
+            profiler:
+                enabled: bool
+
+                #Output directory of trace artifacts
+                output_dir: str
+
+                # `torch.profiler.ProfilerActivity` types to trace
+                cpu: bool
+                cuda: bool
+
+                #Trace options
+                profile_memory: bool
+                with_stack: bool
+                record_shapes: bool
+                with_flops: bool
+
+                # `torch.profiler.schedule` options:
+                # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+                wait_steps: int
+                warmup_steps: int
+                active_steps: int
+                num_cycles: int
+        """
+        # Missing profiler section in config, assume disabled
+        if cfg_profiler is None:
+            cfg_profiler = DictConfig({"enabled": False})
+
+        # Check that component is included and set correctly
+        if cfg_profiler.get("_component_", None) is None:
+            cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
+        else:
+            assert (
+                cfg_profiler.get("_component_")
+                == "torchtune.training.setup_torch_profiler"
+            ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
+
+        profiler, profiler_cfg = config.instantiate(cfg_profiler)
+
+        if self._is_rank_zero:
+            log.info(f" Profiler config after instantiation: {profiler_cfg}")
+
+            self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+            if profiler_cfg["enabled"]:
+                self.profiler_wait_steps = profiler_cfg["wait_steps"]
+                self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+                self.profiler_active_steps = profiler_cfg["active_steps"]
+
+        return profiler
+
     def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
-        memory_efficient_fsdp_wrap: bool,
         fsdp_cpu_offload: bool,
+        reshard_after_forward: bool,
         model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
         quantizer_cfg: Optional[DictConfig] = None,
     ) -> nn.Module:
         """
         Model initialization has some important considerations:
-            a. To minimize GPU peak memory, we load the model on CPU with the right
-               dtype. To ensure that we don't instantiate ``world_size`` number of models,
-               we initialize on meta_device for all ranks other than rank 0.
-            b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
-               model weights from checkpoint.
-            c. While wrapping the model with FSDP, we set ``sync_module_states``
-               to TRUE and broadcast module params and buffers from rank 0.
-            d. The ``device_id`` param ensures that the FSDP initialization happens on
-               the correct device.
+           a. To minimize GPU peak memory, we initialize the model on meta device with
+              the right dtype
+           b. All ranks call ``load_state_dict`` without spiking CPU RAM since
+              full state dicts are loaded with ``torch.load(mmap=True)``
         """
-        if self._is_rank_zero:
-            log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
-            init_start = time.perf_counter()
-
-            with utils.set_default_dtype(self._dtype):
-                model = config.instantiate(cfg_model)
 
+        if self._is_rank_zero:
             log.info(
-                f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
+                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
             )
+            init_start = time.perf_counter()
 
-            # Load both the model weights. This should happen only on Rank 0
-            model.load_state_dict(model_state_dict)
-
-        else:
-            # For non-zero ranks, load the model on meta device
-            with utils.set_default_dtype(self._dtype), torch.device("meta"):
-                model = config.instantiate(cfg_model)
-
-        if self._dtype == torch.bfloat16:
-            model = model.to(torch.bfloat16)
+        with training.set_default_dtype(self._dtype), torch.device("meta"):
+            model = config.instantiate(cfg_model)
 
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
@@ -321,9 +412,6 @@ def _setup_model(
         # ac_mode and ac_option together control selective AC. This is only enabled
         # when these are set AND ``enable_activation_checkpointing`` is set to False
         # We'll clean this up as soon as testing of AC is complete
-        ac_mode = ac_mode
-        ac_option = ac_option
-
         if (not enable_activation_checkpointing) and (ac_mode is not None):
             apply_selective_activation_checkpointing(
                 model,
@@ -331,12 +419,18 @@ def _setup_model(
                 ac_option,
             )
 
+        # original activation checkpointing (full) - flip the condition above
+        if enable_activation_checkpointing and ac_mode is None:
+            training.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+            )
+
         # Apply quantization-aware training during finetuning
         if quantizer_cfg is None:
             raise ValueError("Quantizer must be specified for QAT recipe.")
         quantizer = config.instantiate(quantizer_cfg)
         quantizer.precision = self._dtype
-        quantizer_mode = utils.quantization.get_quantizer_mode(quantizer)
+        quantizer_mode = training.quantization.get_quantizer_mode(quantizer)
         if "qat" not in quantizer_mode:
             raise ValueError(
                 "Quantizer mode '%s' is not supported for finetuning" % quantizer_mode
@@ -344,43 +438,41 @@ def _setup_model(
         self._quantizer_mode = quantizer_mode
         model = quantizer.prepare(model)
 
-        # Wrap the model with FSDP. This will ensure that the model is sharded
-        # across all available GPUs.
-        model = FSDP(
-            module=model,
-            auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy(
-                memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap,
-                modules_to_wrap={modules.TransformerDecoderLayer},
-            ),
-            cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
-            sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
-            device_id=self._device,
-            # this recipe does not currently support mixed precision training
-            mixed_precision=None,
-            # Ensure we broadcast params and buffers from rank 0
-            sync_module_states=True,
-            # Initialize empty modules on all non-zero ranks
-            param_init_fn=(
-                lambda module: module.to_empty(
-                    device=torch.device("cuda"), recurse=False
-                )
-                if not self._is_rank_zero
-                else None
-            ),
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
+            )
+        ]
+        training.shard_model(
+            model=model,
+            shard_conditions=fsdp_shard_conditions,
+            cpu_offload=fsdp_cpu_offload,
+            reshard_after_forward=reshard_after_forward,
         )
 
-        # Ensure no params and buffers are on meta device
-        utils.validate_no_params_on_meta_device(model)
+        with training.set_default_dtype(self._dtype), self._device:
+            for m in model.modules():
+                # RoPE is not covered in state dict
+                if hasattr(m, "rope_init"):
+                    m.rope_init()
 
-        # original activation checkpointing (full) - flip the condition above
-        if enable_activation_checkpointing and ac_mode is None:
-            utils.set_activation_checkpointing(
-                model, auto_wrap_policy={modules.TransformerDecoderLayer}
-            )
+        # This method will convert the full model state dict into a sharded state
+        # dict and load into the model
+        training.load_from_full_model_state_dict(
+            model, model_state_dict, self._device, self._is_rank_zero, strict=True
+        )
+
+        # Ensure no params and buffers are on meta device
+        training.validate_no_params_on_meta_device(model)
 
         if self._is_rank_zero:
-            memory_stats = utils.get_memory_stats(device=self._device)
-            utils.log_memory_stats(memory_stats)
+            log.info(
+                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
+            )
+            memory_stats = training.get_memory_stats(device=self._device)
+            training.log_memory_stats(memory_stats)
 
         # synchronize before training begins
         torch.distributed.barrier()
@@ -390,17 +482,13 @@ def _setup_model(
     def _setup_optimizer(
         self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
     ) -> Optimizer:
-        """
-        Set up the optimizer. This method also handles transforing the state dict
-        for FSDP.
-        """
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
-
         if opt_state_dict:
-            opt_state_dict = FSDP.optim_state_dict_to_load(
-                self._model, optimizer, opt_state_dict
+            training.load_from_full_optimizer_state_dict(
+                optimizer,
+                opt_state_dict,
+                self._device,
             )
-            optimizer.load_state_dict(opt_state_dict)
 
         if self._is_rank_zero:
             log.info("Optimizer is initialized.")
@@ -417,7 +505,7 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = utils.get_world_size_and_rank()
+        world_size, rank = training.get_world_size_and_rank()
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -431,23 +519,25 @@ def _setup_data(
             packed = cfg_dataset.get("packed", False)
 
         sampler = DistributedSampler(
-            ds,
-            num_replicas=world_size,
-            rank=rank,
-            shuffle=shuffle,
-            seed=0,
+            ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
         )
         dataloader = DataLoader(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=partial(
-                utils.padded_collate,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else None,
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
+            collate_fn=(
+                partial(
+                    padded_collate_sft,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else partial(
+                    padded_collate_packed,
+                )
+            ),
         )
 
         if self._is_rank_zero:
@@ -455,57 +545,71 @@ def _setup_data(
 
         return sampler, dataloader
 
-    def save_checkpoint(self, epoch: int) -> None:
+    def save_checkpoint(
+        self,
+        epoch: int,
+    ) -> None:
         """
-        Save state dict to file. The recipe save_checkpoint method is responsible for
-        correctly creating the checkpoint dict and passing to the checkpointer.
+        Checkpoint the state of the recipe. The constructed checkpoint state dict
+        contains the following information:
+        - Model weights with key training.MODEL_KEY
+        - Relevant recipe state if training is not complete
+
+        Checkpointer will save the model weights and recipe state in
+        different checkpoint files. To correctly resume training from an intermediate checkpoint,
+        the model weights and recipe state must be provided.
         """
+        # final dict passed onto the checkpointer
         checkpoint_dict = {}
 
+        intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        with FSDP.state_dict_type(
+        cpu_state_dict = training.get_full_model_state_dict(
             self._model,
-            StateDictType.FULL_STATE_DICT,
-            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
-            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
-        ):
-            cpu_state_dict = self._model.state_dict()
-            opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
+            self._is_rank_zero,
+        )
+
+        if intermediate_checkpoint:
+            opt_state_dict = training.get_full_optimizer_state_dict(
+                self._optimizer,
+                self._is_rank_zero,
+            )
+        else:
+            opt_state_dict = None
 
         # Now that we have the model and opt state dict, create the actual checkpoint dict
         # to be sent to the checkpointer and ultimately written to file
         if self._is_rank_zero:
 
-            checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict})
+            checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})
 
-            # if training is in-progress, checkpoint the optimizer state as well
-            if epoch + 1 < self.total_epochs:
+            # if training is in-progress, checkpoint the optimizer state and recipe state
+            # as well.
+            if intermediate_checkpoint:
                 checkpoint_dict.update(
                     {
-                        utils.OPT_KEY: opt_state_dict,
-                        utils.SEED_KEY: self.seed,
-                        utils.EPOCHS_KEY: self.epochs_run,
-                        utils.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
+                        training.OPT_KEY: opt_state_dict,
+                        training.SEED_KEY: self.seed,
+                        training.EPOCHS_KEY: self.epochs_run,
+                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
+                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
                     }
                 )
 
             self._checkpointer.save_checkpoint(
                 checkpoint_dict,
                 epoch=epoch,
-                intermediate_checkpoint=(epoch + 1 < self.total_epochs),
+                intermediate_checkpoint=intermediate_checkpoint,
             )
 
     def train(self) -> None:
         """
-        The core training loop. Supports training on subsets of the dataset using the
-        ``max_steps_per_epoch``.
+        The core training loop.
         """
         # clean up before training begins
-        utils.cleanup_before_training()
-
-        _, rank = utils.get_world_size_and_rank()
+        training.cleanup_before_training()
+        world_size, rank = training.get_world_size_and_rank()
 
         # zero out the gradients before starting training
         self._optimizer.zero_grad()
@@ -515,6 +619,7 @@ def train(self) -> None:
         running_loss = 0
         num_tokens = 0
 
+        self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
 
@@ -531,6 +636,15 @@ def train(self) -> None:
                 ):
                     break
 
+                # Start tracking CUDA memory for active steps for just the first epoch
+                if (
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+                ):
+                    torch.cuda.memory._record_memory_history()
+
                 # Both are shape [b, s]
                 tokens, labels = batch["tokens"], batch["labels"]
                 # Get the attention mask and position ids from the dataset if they
@@ -545,7 +659,7 @@ def train(self) -> None:
                             "Step 0: Disabling fake quant, will re-enable in step %s"
                             % self._fake_quant_after_n_steps
                         )
-                        disable_fq = utils.quantization._get_disable_fake_quant(
+                        disable_fq = training.quantization._get_disable_fake_quant(
                             self._quantizer_mode
                         )
                         self._model.apply(disable_fq)
@@ -554,43 +668,65 @@ def train(self) -> None:
                             "Step %s: Enabling fake quant"
                             % self._fake_quant_after_n_steps
                         )
-                        enable_fq = utils.quantization._get_enable_fake_quant(
+                        enable_fq = training.quantization._get_enable_fake_quant(
                             self._quantizer_mode
                         )
                         self._model.apply(enable_fq)
 
                 tokens = tokens.to(self._device)
-                num_tokens += tokens.numel()
-                labels = labels.to(self._device)
-                mask = mask.to(self._device) if mask is not None else None
-                input_pos = (
-                    input_pos.to(self._device) if input_pos is not None else None
+
+                # Calculate the number of unmasked tokens in the current batch
+                # and increment the total number of tokens seen in the step
+
+                utils.batch_to_device(batch, self._device)
+
+                current_num_tokens = (
+                    batch["labels"] != self._loss_fn.ignore_index
+                ).sum()
+                num_tokens += current_num_tokens
+                labels = batch.pop("labels")
+
+                logits = self._model(**batch)
+
+                # Shift labels to compute loss
+                # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
+                # But this way we don't need to slice the logits. We just add an ignore index to labels.
+                labels = torch.hstack(
+                    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
                 )
+                if not isinstance(logits, list):
+                    labels = labels.reshape(-1)
+                    logits = logits.reshape(-1, logits.size(-1))
 
-                logits = self._model(tokens, mask=mask, input_pos=input_pos)
-                # Shift so that tokens < n predict n
-                logits = logits[..., :-1, :].contiguous()
-                labels = labels[..., 1:].contiguous()
-                logits = logits.transpose(1, 2)
                 # Compute loss
-                loss = self._loss_fn(logits, labels)
+                current_loss = self._loss_fn(logits, labels) * current_num_tokens
 
-                loss = loss / self._gradient_accumulation_steps
-                running_loss += loss
-                loss.backward()
+                # free logits otherwise it peaks backward memory
+                del logits
+
+                running_loss += current_loss
+                current_loss.backward()
 
                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    # Get total number of tokens across all ranks to normalize gradients
+                    torch.distributed.all_reduce(num_tokens)
+                    # This will ensure that the logged loss matches what we're optimizing
+                    torch.distributed.all_reduce(running_loss)
+                    # Manually scale the gradients from unnormalized loss by total # of tokens
+                    training.scale_grads(self._model, 1 / num_tokens)
+
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
 
                     # Update the number of steps when the weights are updated
                     self.global_step += 1
 
-                    loss_to_log = running_loss.item()
+                    loss_to_log = running_loss.item() / num_tokens
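+                    # instrumentation hook: a no-op by default, replaced by prepare_voir() with the benchmark observer callback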
+                    self.log_loss(loss_to_log)
                     pbar.update(1)
                     pbar.set_description(
-                        f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}"
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                     )
 
                     # Log per-step metrics
@@ -602,10 +738,14 @@ def train(self) -> None:
                         log_dict = {
                             "loss": loss_to_log,
                             "lr": self._optimizer.param_groups[0]["lr"],
-                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                            "tokens_per_second_per_gpu": (
+                                num_tokens / time_per_step * world_size
+                            ),
                         }
                         if self._log_peak_memory_stats:
-                            log_dict.update(utils.get_memory_stats(device=self._device))
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
                         self._metric_logger.log_dict(
                             log_dict,
                             step=self.global_step,
@@ -616,13 +756,65 @@ def train(self) -> None:
                     num_tokens = 0
                     t0 = time.perf_counter()
 
+                    # Stop tracking CUDA memory now that active steps are complete
+                    if (
+                        self._is_rank_zero
+                        and curr_epoch == 0
+                        and self.profiler_profile_memory
+                        and idx
+                        == self.profiler_wait_steps
+                        + self.profiler_warmup_steps
+                        + self.profiler_active_steps
+                    ):
+                        torch.cuda.memory._record_memory_history(enabled=None)
+
+                    # Step profiler
+                    # Note that this is called within gradient accumulation block, hence
+                    # will include multiple forward / backward passes if gradient accumulation > 1
+                    self._profiler.step()
+
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
 
+        self._profiler.stop()
+
     def cleanup(self) -> None:
         if self._is_rank_zero:
             self._metric_logger.close()
-        torch.distributed.destroy_process_group()
+        destroy_process_group()
+
+
+    def log_loss(self, loss):
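+        """No-op loss hook; prepare_voir() replaces it with the benchmark observer's callback."""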
+        pass
+
+
+def prepare_voir(recipe):
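+    """Attach milabench/voir instrumentation to a recipe instance: wraps
+    ``recipe._dataloader`` via ``observer.loader`` and routes loss values to the
+    observer through ``recipe.log_loss``. Returns the observer and ``bench_monitor``."""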
+    from benchmate.observer import BenchObserver
+    from benchmate.monitor import bench_monitor
+
+    def batch_size(x):
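+        # treat the batch size as batch_size * sequence_length, i.e. the number of tokens in the batch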
+        bs, token = x["tokens"].shape
+        return bs * token
+
+    observer = BenchObserver(
+        earlystop=30,
+        raise_stop_program=True,
+        batch_size_fn=batch_size,
+        stdout=True
+    )
+
+    def on_loss(loss):
+        observer.record_loss(loss)
+        observer.step()
+
+    recipe._dataloader = observer.loader(recipe._dataloader, custom_step=True)
+    recipe.log_loss = on_loss
+
+    return observer, bench_monitor
+
 
 
 @config.parse
@@ -634,17 +826,16 @@ def recipe_main(cfg: DictConfig) -> None:
         - Parameters specified in config (see available configs through ``tune ls``)
         - Overwritten by arguments from the command-line
     """
-    if not utils.is_distributed():
+    if not training.is_distributed():
         raise RuntimeError(
             "Distributed QAT recipe should be run via a distributed launcher."
             "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
-
     init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
-        utils.set_torch_num_threads()
+        training.set_torch_num_threads()
 
     config.log_config(recipe_name="QATRecipeDistributed", cfg=cfg)
 
@@ -655,4 +846,4 @@ def recipe_main(cfg: DictConfig) -> None:
 
 
 if __name__ == "__main__":
-    sys.exit(recipe_main())
+    sys.exit(recipe_main())
\ No newline at end of file
diff --git a/benchmarks/llm/requirements.rocm.txt b/benchmarks/llm/requirements.rocm.txt
index ab5098d0..0709f0cf 100644
--- a/benchmarks/llm/requirements.rocm.txt
+++ b/benchmarks/llm/requirements.rocm.txt
@@ -200,7 +200,7 @@ python-dateutil==2.9.0.post0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   pandas
-pytorch-triton-rocm==3.0.0
+pytorch-triton-rocm==3.1.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
@@ -246,7 +246,7 @@ six==1.16.0
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   asttokens
     #   python-dateutil
-sympy==1.13.2
+sympy==1.13.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torch
@@ -254,15 +254,16 @@ tiktoken==0.7.0
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchtune
-torch==2.4.0+rocm6.0
+torch==2.5.1+rocm6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/llm/requirements.in
-torchao==0.3.1
+torchao==0.6.1
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   torchtune
-torchtune==0.2.1
+#torchtune==0.2.1
+git+https://github.com/pytorch/torchtune.git@e1caa9f82fea24d728f9b244a9dd1957f5ed7465
     # via
     #   -c .pin/../.pin/constraints-rocm-torch.txt
     #   -r benchmarks/llm/requirements.in
diff --git a/config/base.yaml b/config/base.yaml
index 3fea53e5..9e470fa3 100644
--- a/config/base.yaml
+++ b/config/base.yaml
@@ -555,7 +555,7 @@ llm-lora-single:
     method: per_gpu
 
   argv:
-    "{milabench_code}/recipes/lora_finetune_single_device.py": true
+    tuneworkaroundrecipes/lora_finetune_single_device.py: true
     --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml"
     epochs=1: true
     output_dir={milabench_extra}/output: true
@@ -576,7 +576,8 @@ llm-lora-ddp-gpus:
   tags:
     - multigpu
   argv:
-    "{milabench_code}/recipes/lora_finetune_distributed.py": true
+    #"{milabench_code}/recipes/lora_finetune_distributed.py": true
+    tuneworkaroundrecipes/lora_finetune_distributed.py: true
     --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml"
     epochs=1: true
     output_dir={milabench_extra}/output: true
@@ -589,6 +590,8 @@ llm-lora-ddp-gpus:
     gradient_accumulation_steps=8: true
 
 
+
+
 llm-lora-ddp-nodes:
   tags:
     - multinode
@@ -599,7 +602,7 @@ llm-lora-ddp-nodes:
     n: 1
   
   argv:
-    "{milabench_code}/recipes/lora_finetune_distributed.py": true
+    tuneworkaroundrecipes/lora_finetune_distributed.py: true
     --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml"
     epochs=1: true
     output_dir={milabench_extra}/output: true
@@ -638,6 +641,7 @@ llm-lora-mp-gpus:
     batch_size=8: true
     gradient_accumulation_steps=1: true
 
+
 llm-full-mp-gpus:
   inherits: _llm
   tags:
@@ -663,6 +667,34 @@ llm-full-mp-gpus:
 
 
 
+# This is a proxy test config to check llm-full-mp-gpus recipe on devices that don't have the memory
+# to run the 70B parameter model.
+# llm-full-mp-gpus:
+#   inherits: _llm
+#   tags:
+#     - multigpu
+#   plan:
+#     method: njobs
+#     n: 1
+
+#   argv:
+#     #"{milabench_code}/recipes/full_finetune_distributed.py": true
+#     tuneworkaroundrecipes/full_finetune_distributed.py: true
+#     --config: "{milabench_code}/configs/llama3_8B_full.yaml"
+#     epochs=1: true
+#     output_dir={milabench_extra}/output: true
+#     tokenizer.path={milabench_data}/llama3_8B/original/tokenizer.model: true
+#     checkpointer.checkpoint_dir={milabench_data}/llama3_8B/: true
+#     checkpointer.output_dir={milabench_data}/llama3_8B/: true
+#     metric_logger.log_dir={milabench_extra}/metrics: true
+#     repo_id="meta-llama/Meta-Llama-3.1-8B": true
+#     safetensors=true: true
+#     batch_size=2: true
+#     gradient_accumulation_steps=1: true
+
+
+
+
 llm-full-mp-nodes:
   tags:
     - multinode
@@ -673,7 +705,7 @@ llm-full-mp-nodes:
     n: 1
 
   argv:
-    "{milabench_code}/recipes/full_finetune_distributed.py": true
+    tuneworkaroundrecipes/full_finetune_distributed.py: true
     --config: "{milabench_code}/configs/llama3_70B_full.yaml"
     epochs=1: true
     output_dir={milabench_extra}/output: true