Updated recipes and configs
rkarhila-amd committed Nov 14, 2024
1 parent 3d83577 commit 37b35f7
Showing 9 changed files with 489 additions and 202 deletions.
15 changes: 8 additions & 7 deletions .pin/constraints-rocm-torch.txt

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion benchmarks/llm/configs/llama3_70B_full.yaml
@@ -73,7 +73,7 @@ checkpointer:
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
batch_size: 1 #2
epochs: 3

optimizer:
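
The per-device batch size for the full 70B fine-tune drops from 2 to 1. These YAML files are consumed by the recipes below as OmegaConf DictConfig objects (see the cfg: DictConfig signatures), so a value such as batch_size can also be read or overridden without editing the file. A minimal sketch, with the file path assumed:

# Minimal sketch (file path assumed); torchtune configs are plain YAML that
# OmegaConf can load, and the recipes receive them as DictConfig objects.
from omegaconf import OmegaConf

cfg = OmegaConf.load("benchmarks/llm/configs/llama3_70B_full.yaml")
print(cfg.batch_size)  # 1 after this commit

# Merge an override the same way a "key=value" pair on the command line would be.
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["batch_size=2"]))
print(cfg.batch_size)  # 2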
6 changes: 3 additions & 3 deletions benchmarks/llm/configs/llama3_70B_lora.yaml
@@ -14,8 +14,8 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_rank: 16
lora_alpha: 32
lora_rank: 8 #16
lora_alpha: 16 #32

tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -68,7 +68,7 @@ dataset:
_component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True
batch_size: 2
batch_size: 1 #2

# Optimizer and Scheduler
optimizer:
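
Halving lora_rank and lora_alpha together (16/32 down to 8/16) keeps the usual LoRA scaling factor alpha/rank at 2 while halving the adapter parameter count, which is what reduces memory use here. Below is a minimal, self-contained sketch of where the two values enter a LoRA linear layer; TinyLoRALinear is an illustration only, not torchtune's LoRALinear.

# Illustration only: a toy LoRA layer, not torchtune's LoRALinear.
import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """y = W x + (alpha / rank) * B(A(x)), with W frozen."""

    def __init__(self, in_features, out_features, rank=8, alpha=16.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)                   # frozen base weight
        self.lora_a = nn.Linear(in_features, rank, bias=False)   # A: d_in -> r
        self.lora_b = nn.Linear(rank, out_features, bias=False)  # B: r -> d_out
        nn.init.zeros_(self.lora_b.weight)                       # adapter starts as a no-op
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

# rank=8, alpha=16 (this commit) and rank=16, alpha=32 (before) both give scaling 2.0,
# but the smaller rank halves the number of trainable adapter parameters.
layer = TinyLoRALinear(4096, 4096, rank=8, alpha=16.0)
print(layer.scaling)  # 2.0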
10 changes: 9 additions & 1 deletion benchmarks/llm/recipes/full_finetune_single_device.py
@@ -1,4 +1,12 @@
#!/usr/bin/env python
#!/usr/bin/env python3
# As of November 2024, torchtune development is very rapid.
# This recipe is based on the torchtune recipe at git commit e137afe (post release 0.3.1):
# https://github.com/pytorch/torchtune/blob/7bfb3336446f0d874ab5d4595249839b735b7076/recipes/full_finetune_single_device.py

# Torchtune 0.2.1 recipe with device instrumentation (c) Mila
# https://github.com/mila-iqia/milabench/blob/a60a3aae21e87e46bcce403620a3f56c12878554/benchmarks/llm/recipes/full_finetune_single_device.py

# The instrumentation edits (c) AMD

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
53 changes: 50 additions & 3 deletions benchmarks/llm/recipes/lora_finetune_distributed.py
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -270,7 +272,6 @@ def setup(self, cfg: DictConfig) -> None:

checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
self._compile = cfg.get("compile", False)

self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=self._enable_activation_checkpointing,
@@ -453,15 +454,33 @@ def _setup_model(
with training.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

if self._is_rank_zero:
log.info(
"model instantiated based on config on Rank 0 ..."
)


set_trainable_params(model, get_adapter_params(model))

if self._is_rank_zero:
log.info(
"trainable parameters set..."
)


if self._compile:
if self._is_rank_zero: log.info(f"{self._is_rank_zero} compiling...")

training.compile_model(model, verbose=self._is_rank_zero)
if self._is_rank_zero: log.info(f"{self._is_rank_zero} done compiling")


if enable_activation_checkpointing:
if self._is_rank_zero: log.info(f"{self._is_rank_zero} setting activation checkpointing")
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
if self._is_rank_zero: log.info(f"{self._is_rank_zero} done setting activation checkpointing")

# For FSDP sharding
fsdp_shard_conditions = [
@@ -470,40 +489,54 @@ def _setup_model(
names_to_match=custom_sharded_layers,
)
]

if self._is_rank_zero: log.info(f"{self._is_rank_zero} sharding model")

training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)

if self._is_rank_zero: log.info(f"{self._is_rank_zero} done sharding model")

if lora_weights_state_dict:
if self._is_rank_zero: log.info(f"{self._is_rank_zero} loading state dict")

lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
if self._is_rank_zero: log.info(f"{self._is_rank_zero} loaded state dict")

else:
lora_missing, lora_unexpected = None, None

# Initialize LoRA params and RoPE buffers
with training.set_default_dtype(self._dtype), self._device:
lora_device = "cpu" if fsdp_cpu_offload else self._device
for m in model.modules():
for i,m in enumerate(model.modules()):
if (
isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
) and not lora_weights_state_dict:

# LoRA params may not be covered in the state dict
# when fine-tuning for the first time
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
m.initialize_parameters()

# RoPE is not covered in state dict
if hasattr(m, "rope_init"):
m.rope_init()


if self._is_rank_zero: log.info(f"{self._is_rank_zero} load from full model state dict")

base_missing, base_unexpected = training.load_from_full_model_state_dict(
model,
base_model_state_dict,
@@ -515,9 +548,15 @@ def _setup_model(
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
is_dora = True
if self._is_rank_zero: log.info(f"{self._is_rank_zero} init dora")
m.initialize_dora_magnitude()
if self._is_rank_zero: log.info(f"{self._is_rank_zero} done init dora")

if is_dora:
if self._is_rank_zero: log.info(f"{self._is_rank_zero} load dora")
load_dora_magnitudes(model)
if self._is_rank_zero: log.info(f"{self._is_rank_zero} done load dora")
if self._is_rank_zero: log.info(f"{self._is_rank_zero} validate missing")
validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
@@ -527,14 +566,22 @@ def _setup_model(
lora_missing=lora_missing,
lora_unexpected=lora_unexpected,
)
if self._is_rank_zero: log.info(f"{self._is_rank_zero} done validate missing")

if self._is_rank_zero: log.info(f"{self._is_rank_zero} validate no params")

# Ensure no params and buffers are on meta device
training.validate_no_params_on_meta_device(model)

if self._is_rank_zero: log.info(f"{self._is_rank_zero} done validate no params")


if self._is_rank_zero: log.info(f"{self._is_rank_zero} activation offload")
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)

if self._is_rank_zero: log.info(f"{self._is_rank_zero} done activation offload")
# log
if self._is_rank_zero:
log.info(
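
Most of the instrumentation added to this recipe is the repeated pattern `if self._is_rank_zero: log.info(...)` wrapped around each setup step. A small helper in this spirit could keep the call sites shorter; this is only a sketch, not part of the commit, and the logger below stands in for the module-level log the recipe already defines.

# Sketch only, not part of this commit. The logger here stands in for the
# recipe's existing module-level `log`.
import logging
import time
from contextlib import contextmanager

log = logging.getLogger(__name__)

def log_rank_zero(is_rank_zero: bool, msg: str) -> None:
    """Log an info message on rank 0 only."""
    if is_rank_zero:
        log.info(msg)

@contextmanager
def timed_step(is_rank_zero: bool, name: str):
    """Log the start and end of a setup step (with duration) on rank 0."""
    log_rank_zero(is_rank_zero, f"{name}...")
    start = time.perf_counter()
    try:
        yield
    finally:
        log_rank_zero(is_rank_zero, f"done {name} ({time.perf_counter() - start:.1f}s)")

# Example call site, mirroring the sharding step above:
#   with timed_step(self._is_rank_zero, "sharding model"):
#       training.shard_model(model=model, shard_conditions=fsdp_shard_conditions,
#                            cpu_offload=fsdp_cpu_offload,
#                            reshard_after_forward=reshard_after_forward)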
9 changes: 8 additions & 1 deletion benchmarks/llm/recipes/lora_finetune_single_device.py
@@ -1,5 +1,12 @@
#!/usr/bin/env python
#!/usr/bin/env python3
# As of November 2024, torchtune development is very rapid.
# This recipe is based on the torchtune recipe at git commit e137afe (post release 0.3.1):
# https://github.com/pytorch/torchtune/blob/7bfb3336446f0d874ab5d4595249839b735b7076/recipes/lora_finetune_single_device.py

# Torchtune 0.2.1 recipe with device instrumentation (c) Mila
# https://github.com/mila-iqia/milabench/blob/a60a3aae21e87e46bcce403620a3f56c12878554/benchmarks/llm/recipes/lora_finetune_single_device.py

# The instrumentation edits (c) AMD
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
