v0.1, 9/24/2024
Add the necessary imports:
import torch
from torch.profiler import profile, record_function, ProfilerActivity
In a distributed training setup, where multiple processes (or GPUs) are involved, it's essential to ensure that each process writes to a separate trace file to avoid race conditions. You can achieve this by including the process rank in the trace filename. The trace files generated can then be visualized using tools like Perfetto or Chrome trace viewer.
Ensure that the directory where traces are stored has the proper write permissions for all processes.
def trace_handler(p):
    # Get the rank of the current process (GPU)
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    # Generate a unique trace file for each process using its rank
    trace_filename = f"/path/to/traces/trace_model_size_rank{rank}_step{p.step_num}.json"
    # Export the trace for analysis in Perfetto or Chrome tracing
    p.export_chrome_trace(trace_filename)
Wrap the training loop in this profiling context to capture the trace. Make sure that everything you want to capture (data loading, communication, checkpointing, etc.) is included within this context, and step the profiler forward at the end of each training iteration.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=5,
    ),
    on_trace_ready=trace_handler,
) as p:
    for batch in dataloader:
        train_step(batch)
        p.step()
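Optionally, the record_function context manager imported earlier can be used to label logical regions of each iteration so that they show up as named spans in the trace. A minimal sketch of annotating the loop body inside the profiling context (the labels are arbitrary, and the snippet assumes each batch is a single tensor):
for batch in dataloader:
    with record_function("copy_to_gpu"):   # appears as a named span in the trace
        batch = batch.to("cuda", non_blocking=True)
    with record_function("train_step"):
        train_step(batch)
    p.step()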
More information on how to set the schedule parameters to efficiently analyze long running training jobs can be found here: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-long-running-jobs
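For example, a schedule along the following lines (the parameter values are illustrative) skips the first iterations of the job and then records a few short windows instead of every step, which keeps trace files manageable for long-running training:
# Skip the first 10 steps, then repeat cycles of 5 idle steps, 1 warmup step,
# and 3 actively recorded steps. repeat=2 stops after two cycles;
# repeat=0 would keep cycling until training ends.
long_running_schedule = torch.profiler.schedule(
    skip_first=10,
    wait=5,
    warmup=1,
    active=3,
    repeat=2,
)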
The profiler trace will be saved to the JSON file constructed in trace_handler (one file per rank and profiling step). Open this file in a trace viewer like Perfetto or the Chrome trace viewer to visualize the trace.
While scaling the RegNet+GPT workload to multiple nodes, we observed that some model configurations showed very large performance degradation. Here we show an example of how we used the profiler to give us clues that helped debug this issue.
First, we isolated just the GPT part and measured its multi-node scaling. We observed that the performance drop was small and the NCCL/RCCL communication kernels took ~1ms each:
Next, we profiled the combined RegNet+GPT workload that showed very poor multi-node scaling. The trace showed us that the communication kernels for the convolutional RegNet did not cause a major bottleneck and that the GPU utilization for this network was good. This ruled out any inefficiencies in the RegNet being the cause of the poor scaling. However, we noticed that the GPT communication kernels took around 500x longer in the combined RegNet+GPT case than in the isolated GPT case:
From this observation, we hypothesized that running the model with close to full GPU memory utilization causes smaller communication buffers to be allocated, resulting in greater transmission delay.
Based on this hypothesis, we tested the RegNet+GPT workload with a batch size one smaller than the maximum possible. With this reduced batch size, the communication bottleneck was alleviated and we saw more reasonable multi-node scaling.
In the RegNet+GPT workload, the random data generation and transfer to GPU memory takes up a significant amount of time and is not overlapped with computation, leaving the GPU idle. This bottleneck can be identified using the CPU and GPU trace visualization as shown in the example screenshot below:
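One common way to reduce this kind of idle time is to generate batches in DataLoader worker processes and copy them to the GPU asynchronously from pinned host memory. The sketch below is illustrative rather than the exact fix used for this workload; the synthetic dataset, batch size, and train_step are placeholders:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic CPU dataset standing in for the random data generation in the workload.
dataset = TensorDataset(torch.randn(256, 3, 64, 64), torch.randint(0, 10, (256,)))

# pin_memory=True places batches in page-locked host memory, and num_workers > 0
# moves data generation off the main process so it can run ahead of the GPU.
loader = DataLoader(dataset, batch_size=32, pin_memory=True, num_workers=2)

device = torch.device("cuda")
for images, labels in loader:
    # non_blocking=True makes the host-to-device copy asynchronous with respect
    # to the CPU, so it can overlap with already-enqueued GPU kernels and with
    # fetching the next batch.
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    # train_step(images, labels)  # forward/backward/optimizer step would go here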
Prerequisite: PyTorch - Understanding GPU Memory
To analyze memory usage during the training process, we can use PyTorch's CUDA memory management APIs to record memory events and generate memory snapshots. This provides detailed information about memory allocations and frees, which can be visualized for better understanding and optimization.
First, we start by recording the memory allocation and free events over the course of the training loop. By setting up the recording mechanism before the training loop, we can capture memory-related events like allocations, deallocations, and re-use of memory blocks.
Example code snippet to start and stop memory history recording:
import logging

logger = logging.getLogger(__name__)

# Upper bound on the number of allocator events kept in the history
# (e.g. 100000), so long-running jobs do not accumulate unbounded data.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT = 100000

def start_record_memory_history() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not recording memory history")
        return
    logger.info("Starting snapshot record_memory_history")
    torch.cuda.memory._record_memory_history(
        max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
    )

def stop_record_memory_history() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not recording memory history")
        return
    logger.info("Stopping snapshot record_memory_history")
    torch.cuda.memory._record_memory_history(enabled=None)
In this function, we specify MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT to limit the number of memory events captured, ensuring we don't overwhelm the system with excessive data for long-running jobs.
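As a rough sketch of how these helpers fit around a training loop (train_step, dataloader, and the number of recorded steps are placeholders; export_memory_snapshot is defined in the next snippet):
start_record_memory_history()

# Record allocator events for a handful of iterations, then dump a snapshot.
for step, batch in enumerate(dataloader):
    train_step(batch)
    if step == 5:
        break

export_memory_snapshot()       # writes a <host>_<timestamp>.pickle snapshot
stop_record_memory_history()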
To visualize memory usage, we can export memory snapshots at any point during or after the training. These snapshots contain detailed information about memory allocations and can be saved to disk in a format that can later be analyzed using external tools or custom scripts.
import socket
from datetime import datetime

# Timestamp format used in the snapshot file names (e.g. "%b_%d_%H_%M_%S").
TIME_FORMAT_STR = "%b_%d_%H_%M_%S"

def export_memory_snapshot() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not exporting memory snapshot")
        return

    # Prefix for file names.
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = f"{host_name}_{timestamp}"

    try:
        logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
        torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
    except Exception as e:
        logger.error(f"Failed to capture memory snapshot {e}")
        return
This method generates a memory snapshot and saves it as a .pickle file, which can be visualized by dragging and dropping it into the user interface provided by PyTorch at https://pytorch.org/memory_viz. This tool allows users to adjust the level of detail by filtering out smaller memory events to simplify the view.
The image above illustrates a memory profile. The upward slope represents the forward pass where the activations are allocated, highlighting the increasing memory usage as more activations are stored. Conversely, the downward slope indicates the backward pass where gradients are computed and activations are deallocated, freeing up memory. The horizontal lines spanning the full timeline represent static memory, such as model parameters and optimizer state. Hovering over the user interface provides detailed information about memory allocations, including the memory address and the specific code trace that triggered the allocation. It also shows the size of each allocation, helping developers understand and optimize memory usage during model training.