Started instrumenting recipes from newer torchtune for milabench
rkarhila-amd committed Nov 12, 2024
1 parent a60a3aa commit 0fa1419
Showing 10 changed files with 997 additions and 494 deletions.
5 changes: 3 additions & 2 deletions benchmarks/llm/configs/llama3_70B_full.yaml
@@ -33,7 +33,7 @@ model:
 
 safetensors: true
 checkpointer:
-  _component_: torchtune.utils.FullModelHFCheckpointer
+  _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
   checkpoint_files: [
     model-00001-of-00030.safetensors,
@@ -95,6 +95,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: True
 memory_efficient_fsdp_wrap: True
 fsdp_cpu_offload: True
@@ -103,7 +104,7 @@ dtype: bf16
 
 # Logging
 metric_logger:
-  _component_: torchtune.utils.metric_logging.DiskLogger
+  _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
 output_dir: /tmp/alpaca-llama3-finetune
 log_every_n_steps: 1
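The recurring edit across these configs moves component paths from the torchtune.utils namespace to torchtune.training, matching the module reorganization in newer torchtune releases. Each `_component_` value is a dotted path that the recipe resolves and instantiates with the node's remaining keys. A minimal sketch of that resolution, assuming OmegaConf handles the YAML (torchtune's own `torchtune.config.instantiate` adds validation on top of this idea):

```python
# Minimal sketch of _component_ resolution, assuming OmegaConf-backed configs.
import importlib

from omegaconf import OmegaConf

def instantiate(node):
    """Import the dotted _component_ path and call it with the remaining keys."""
    kwargs = OmegaConf.to_container(node, resolve=True)
    dotted = kwargs.pop("_component_")
    module_name, _, attr = dotted.rpartition(".")
    component = getattr(importlib.import_module(module_name), attr)
    return component(**kwargs)

cfg = OmegaConf.create({
    "_component_": "torchtune.training.metric_logging.DiskLogger",
    "log_dir": "/tmp/alpaca-llama3-finetune",  # literal stand-in for ${output_dir}
})
logger = instantiate(cfg)  # builds the DiskLogger named in the diff above
```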
5 changes: 3 additions & 2 deletions benchmarks/llm/configs/llama3_70B_lora.yaml
@@ -23,7 +23,7 @@ tokenizer:
 
 safetensors: true
 checkpointer:
-  _component_: torchtune.utils.FullModelHFCheckpointer
+  _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
   checkpoint_files: [
     model-00001-of-00030.safetensors,
@@ -90,7 +90,7 @@ gradient_accumulation_steps: 1
 # Logging
 output_dir: /tmp/lora_finetune_output
 metric_logger:
-  _component_: torchtune.utils.metric_logging.DiskLogger
+  _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
 log_every_n_steps: 1
 log_peak_memory_stats: False
@@ -99,3 +99,4 @@ log_peak_memory_stats: False
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: True
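Both 70B configs now pair activation checkpointing with `enable_activation_offloading: True`, which parks activations saved for backward in CPU RAM and trades GPU memory for host-device traffic. Torchtune ships its own implementation behind this flag; as an illustration only, stock PyTorch exposes the same idea via `torch.autograd.graph.save_on_cpu`:

```python
# Illustration of activation offloading with plain PyTorch (requires a CUDA GPU).
# This is not torchtune's implementation, just the underlying idea.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with torch.autograd.graph.save_on_cpu(pin_memory=True):
    # Tensors saved for backward live in pinned host memory instead of VRAM.
    loss = model(x).sum()
loss.backward()  # activations stream back to the GPU as gradients are computed
```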
4 changes: 2 additions & 2 deletions benchmarks/llm/configs/llama3_8B_lora.yaml
@@ -32,7 +32,7 @@ model:
   lora_alpha: 16
 
 checkpointer:
-  _component_: torchtune.utils.FullModelMetaCheckpointer
+  _component_: torchtune.training.FullModelMetaCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
   checkpoint_files: [
     consolidated.00.pth
@@ -69,7 +69,7 @@ gradient_accumulation_steps: 32
 # Logging
 output_dir: /tmp/lora_finetune_output
 metric_logger:
-  _component_: torchtune.utils.metric_logging.DiskLogger
+  _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
 log_every_n_steps: 1
 log_peak_memory_stats: False
6 changes: 3 additions & 3 deletions benchmarks/llm/configs/llama3_8B_lora_single_device.yaml
@@ -31,7 +31,7 @@ tokenizer:
   path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
 
 checkpointer:
-  _component_: torchtune.utils.FullModelMetaCheckpointer
+  _component_: torchtune.training.FullModelMetaCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
   checkpoint_files: [
     consolidated.00.pth
@@ -69,7 +69,7 @@ compile: False
 # Logging
 output_dir: /tmp/lora_finetune_output
 metric_logger:
-  _component_: torchtune.utils.metric_logging.DiskLogger
+  _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
 log_every_n_steps: 1
 log_peak_memory_stats: False
@@ -81,7 +81,7 @@ enable_activation_checkpointing: True
 
 # Profiler (disabled)
 profiler:
-  _component_: torchtune.utils.setup_torch_profiler
+  _component_: torchtune.training.setup_torch_profiler
   enabled: False
 
 #Output directory of trace artifacts
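The profiler block only changes its component path and stays disabled. For orientation, a hedged sketch of the kind of profiling loop such a section drives, written against the public torch.profiler API rather than torchtune's setup_torch_profiler wrapper (the /tmp/trace output directory is illustrative):

```python
# Sketch of a profiled training loop using the public torch.profiler API.
import torch
from torch.profiler import ProfilerActivity, profile, schedule

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("/tmp/trace"),
) as prof:
    for step in range(6):
        # a real training step would run here
        prof.step()  # advance the wait/warmup/active schedule
```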
6 changes: 3 additions & 3 deletions benchmarks/llm/configs/llama3_8B_qat_full.yaml
@@ -29,7 +29,7 @@ model:
   _component_: torchtune.models.llama3_1.llama3_1_8b
 
 checkpointer:
-  _component_: torchtune.utils.FullModelMetaCheckpointer
+  _component_: torchtune.training.FullModelMetaCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
   checkpoint_files: [
     consolidated.00.pth
@@ -45,7 +45,7 @@ epochs: 3
 
 # QAT arguments
 quantizer:
-  _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
   groupsize: 256
 
 optimizer:
@@ -70,7 +70,7 @@ dtype: bf16
 
 # Logging
 metric_logger:
-  _component_: torchtune.utils.metric_logging.DiskLogger
+  _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
 output_dir: /tmp/alpaca-llama3-finetune
 log_every_n_steps: 1
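The QAT recipe's quantizer, Int8DynActInt4WeightQATQuantizer, fake-quantizes weights to int4 in groups of `groupsize: 256` (with int8 dynamic quantization of activations) so the model learns to tolerate quantization error before export. A toy sketch of the weight-side fake quantization with a straight-through estimator, assuming the tensor size divides the group size; the real kernels behind this quantizer differ:

```python
# Toy groupwise int4 fake quantization with a straight-through estimator.
import torch

def fake_quantize_int4(w: torch.Tensor, group_size: int = 256) -> torch.Tensor:
    wg = w.reshape(-1, group_size)                      # one scale per group
    scale = wg.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 7.0
    q = (wg / scale).round().clamp(-8, 7)               # int4 range [-8, 7]
    dq = (q * scale).reshape(w.shape)
    # Straight-through estimator: forward sees dq, backward sees identity.
    return w + (dq - w).detach()

w = torch.randn(1024, 512, requires_grad=True)
loss = fake_quantize_int4(w).pow(2).sum()
loss.backward()  # gradients reach w despite the non-differentiable rounding
```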
6 changes: 3 additions & 3 deletions benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml
@@ -30,7 +30,7 @@ tokenizer:
   path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
 
 checkpointer:
-  _component_: torchtune.utils.FullModelMetaCheckpointer
+  _component_: torchtune.training.FullModelMetaCheckpointer
   checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
   checkpoint_files: [
     consolidated.00.pth
@@ -68,7 +68,7 @@ compile: False
 # Logging
 output_dir: /tmp/qlora_finetune_output/
 metric_logger:
-  _component_: torchtune.utils.metric_logging.DiskLogger
+  _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
 log_every_n_steps: 1
 log_peak_memory_stats: False
@@ -80,7 +80,7 @@ enable_activation_checkpointing: True
 
 # Profiler (disabled)
 profiler:
-  _component_: torchtune.utils.setup_torch_profiler
+  _component_: torchtune.training.setup_torch_profiler
   enabled: False
 
 #Output directory of trace artifacts
(Diff for the remaining 4 changed files not shown.)
