First Steps With Distributed Data Parallelism and PyTorch Profiling
2025-09-24
In this post we will walk through implementing a simplified version of Distributed Data Parallelism (DDP), improve it step by step, and use it to train a model on a single node with four Nvidia T4 GPUs. Before we train on multiple GPUs, though, we will begin by training on a single GPU and work our way up.
Along the way we will use the PyTorch profiler to visualize the execution and compare the traces to see the differences in execution during the different training runs.
We'll run our tests with the SmolLM2-360M-Instruct model and we'll use the nbdistributed library to run distributed training in a notebook.
The nbdistributed library was created by @TheZachMueller, and it's an awesome way to interact with multiple GPUs inside of a Jupyter notebook. I used Modal Notebooks to run this code.
If you want to learn more about large scale model training you should sign up for Zach's course!
What is Distributed Data Parallelism?
Before we can implement distributed data parallelism in our code, we must understand it. In data parallelism, the "parallelism" comes from training multiple instances of our model on different shards of our dataset.
In each iteration of the training loop our models train in parallel on their respective batch of data. We can think of this set of batches as a "global batch" in the data parallel world while considering each batch as a "micro batch".
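To make the terminology concrete, here is a small pure-Python sketch, with no GPUs involved, that splits one global batch into per-GPU micro batches. `split_global_batch` is a hypothetical helper for illustration only, and it assumes the global batch size is divisible by the number of GPUs.

```python
def split_global_batch(global_batch: list, world_size: int) -> list[list]:
    """Split a global batch into `world_size` equally sized micro batches."""
    if len(global_batch) % world_size != 0:
        raise ValueError("global batch size must be divisible by world_size")
    micro = len(global_batch) // world_size
    # Micro batch i is the slice that GPU (rank) i trains on during this step
    return [global_batch[i * micro:(i + 1) * micro] for i in range(world_size)]


# 4 GPUs with a micro batch size of 16 (the get_dataloader default) -> 64 examples per step
global_batch = list(range(64))
micro_batches = split_global_batch(global_batch, world_size=4)
print([len(mb) for mb in micro_batches])  # [16, 16, 16, 16]
print(micro_batches[1][:3])               # [16, 17, 18]
```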
In this exploration we do not implement any model sharding strategies which means we will allocate one model per GPU.
Setup
Install dependencies
!uv pip install --system datasets HolisticTraceAnalysis "git+https://github.com/muellerzr/nbdistributed"
Load and start nbdistributed
%load_ext nbdistributed
We initialize nbdistributed with four GPUs!
%dist_init --num-processes 4
Starting 4 distributed workers...
✓ Successfully started 4 workers
Available commands:
  %%distributed - Execute code on all ranks (explicit)
  %%rank [0,n] - Execute code on specific ranks
  %sync - Synchronize all ranks
  %dist_status - Show worker status
  %dist_mode - Toggle automatic distributed mode
  %dist_shutdown - Shutdown workers
🚀 Distributed mode active: All cells will now execute on workers automatically!
   Magic commands (%, %%) will still execute locally as normal.
🐍 Below are auto-imported and special variables auto-generated into the namespace to use
  `torch`
  `dist`: `torch.distributed` import alias
  `rank` (`int`): The local rank
  `world_size` (`int`): The global world size
  `gpu_id` (`int`): The specific GPU ID assigned to this worker
  `device` (`torch.device`): The current PyTorch device object (e.g. `cuda:1`)
Define helper functions
import random
import numpy as np
import torch
# Enable TF32 for matrix multiplications (TF32 requires an Ampere or newer GPU;
# on Turing-based T4s these flags have no effect)
torch.backends.cuda.matmul.allow_tf32 = True
# Enable TF32 for cuDNN (convolution operations)
torch.backends.cudnn.allow_tf32 = True
common_seed = 314
def set_seed(seed: int = common_seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
from transformers import AutoModelForSequenceClassification, AutoTokenizer
MODEL_NAME = "HuggingFaceTB/SmolLM2-360M-Instruct"
def get_smol_model():
    return AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2, torch_dtype="bfloat16"
    )
def get_smol_tokenizer():
    return AutoTokenizer.from_pretrained(MODEL_NAME)
from datasets import load_dataset
def get_dataset():
    dataset = load_dataset("glue", "mrpc")
    tokenizer = get_smol_tokenizer()
    def tokenize_func(examples):
        return tokenizer(
            examples["sentence1"],
            examples["sentence2"],
            max_length=None,
            truncation=True,
            padding=True,
        )
    dataset = dataset.map(
        tokenize_func, batched=True, remove_columns=["idx", "sentence1", "sentence2"]
    )
    dataset = dataset.rename_columns({"label": "labels"})
    return dataset
Let's use the helper functions we defined to load our tokenizer and dataset.
tokenizer = get_smol_tokenizer()
dataset = get_dataset()["train"]
Now we define another helper function for creating a dataloader.
from torch.utils.data import DataLoader
train_ds = dataset.shuffle(seed=common_seed)
def collate_func(batch):
    return tokenizer.pad(
        batch,
        padding="longest",
        max_length=None,
        pad_to_multiple_of=8,
        return_tensors="pt",
    )
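def get_dataloader(ds, batch_size=16, seed=common_seed):
    # Note: `seed` is not used below; with shuffle=True and no `generator`, the
    # DataLoader draws its shuffle order from PyTorch's global RNG (seeded by set_seed)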
    return DataLoader(
        ds,
        batch_size=batch_size,
        collate_fn=collate_func,
        drop_last=True,
        shuffle=True
    )
Set up some helper functions for analyzing our traces with the HolisticTraceAnalysis library.
from hta.trace_analysis import TraceAnalysis
import json
def get_trace_dicts(trace_dir: str):
    analyzer = TraceAnalysis(trace_dir = trace_dir)
    time_spent_df = analyzer.get_temporal_breakdown().to_dict()
    overlap_df = analyzer.get_comm_comp_overlap().to_dict()
    kernel_type_metrics_df, _ = analyzer.get_gpu_kernel_breakdown()
    result = {
        "time_spent_df": time_spent_df,
        "overlap_df": overlap_df,
        "kernel_type_metrics_df": kernel_type_metrics_df.to_dict(),
    }
    return result
Single GPU Training
Data Preparation and Initializing the Models
Let's use the helper functions we just defined to train a model on a single GPU. We initialize the model, optimizer, and dataloader below.
%%rank [0]
# This tells `nbdistributed` to only execute this code cell on GPU rank 0
model = get_smol_model() # Initialize the model
model.to(device) # Send it to the GPU
model.train()
# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_ds = dataset.shuffle(seed=common_seed) # Shuffle the dataset using our seed for consistency
train_dataloader = get_dataloader(train_ds) # Create a dataloader from our training dataset
Profiling
We can use the PyTorch profiler to visualize execution time and GPU memory usage. To learn more about getting started with the PyTorch profiler, refer to the section on profiling in The Ultra Scale Playbook, the PyTorch Profiler documentation, and this tutorial recipe. Once we've used the profiler to capture our traces, we can analyze them using the HolisticTraceAnalysis library.
We call torch.profiler.profile, which creates a profiler context. Inside it we run our training loop as usual and simply call prof.step() as the last statement of each iteration.
Let's create a helper function to create a profiler instance.
def get_profiler(file_prefix: str, steps: int = 6):
    worker_name=f"gpu-rank-0{dist.get_rank()}"
    tensorboard_trace_handler = torch.profiler.tensorboard_trace_handler(
        f"{file_prefix}",
        worker_name=worker_name,
        use_gzip=True,
    )
    def trace_handler(prof: torch.profiler.profile):
        tensorboard_trace_handler(prof)
    return torch.profiler.profile(
        activities=[
            # profile activity on the CPU and GPU
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        # Setup the profiler schedule
        schedule=torch.profiler.schedule(wait=0, warmup=1, active=steps - 1),
        # Save the trace files when it is ready
        on_trace_ready=trace_handler,
        profile_memory=True,
        record_shapes=True,
        with_stack=True,
    )
There are multiple options that can be configured when setting up the PyTorch profiler. Above, we've set the profiler up to profile both the CPU and the GPU. The profiler schedule accepts various parameters; with the default steps=6, our configuration warms up for 1 step and records the remaining 5 of the 6 total training steps.
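To see how the schedule maps training steps to profiler actions, we can call the function returned by `torch.profiler.schedule` directly with a step index. This is a CPU-only sketch that assumes a recent PyTorch version; the action names come from the `torch.profiler.ProfilerAction` enum.

```python
import torch

# Same configuration as get_profiler with steps=6: no wait, 1 warmup step, 5 active steps
sched = torch.profiler.schedule(wait=0, warmup=1, active=5)

# The schedule is a callable that maps a step index to a ProfilerAction
actions = [sched(step) for step in range(6)]
for step, action in enumerate(actions):
    print(step, action.name)
# 0 WARMUP
# 1 RECORD
# 2 RECORD
# 3 RECORD
# 4 RECORD
# 5 RECORD_AND_SAVE
```

The last active step is marked RECORD_AND_SAVE; once it finishes, the profiler hands the collected trace to `on_trace_ready`, which is where our `trace_handler` writes the file.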
%%rank [0]
max_steps = 6
with get_profiler("/workspace/ddp/single") as prof:
    for (i, batch) in enumerate(train_dataloader):
        if i >= max_steps:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        output = model(**batch) # Make a prediction
        output.loss.backward() # Backpropagate to calculate the gradients
        # Step the optimizer
        optimizer.step()
        optimizer.zero_grad()
        # Step the profiler
        prof.step()
After we run this, there will be a trace file inside the /workspace/ddp/single directory. This file can be opened in Chrome's trace viewer, which we can access by entering chrome://tracing into the URL bar, to see the GPU execution.
Let's look at our trace!
From top to bottom the rows represent execution time during:
- Forward pass on the CPU - thread 29 (python)
- Backward pass on the CPU - thread 208 (pt_autograd_0)
- GPU kernel execution - stream 7
The fact that we do not see any gaps in the GPU stream means our GPU is not idling and our utilization is consistent throughout training. That is what we would expect when training on a single GPU.
The results of HolisticTraceAnalysis are rendered in the charts above.
- In the "Temporal Breakdown" column chart we can see that > 98% of GPU time was spent doing computations.
- The "Computation-Communication Overlap" column chart shows that no computation overlaps with communication. This is expected because we are only using a single GPU and don't need to communicate information to any other GPUs.
- The "Kernel Type Distribution" pie chart shows the distribution of kernel operation types. We won't focus on it in this post, but the data is still interesting to look at.
We will compare these same metrics for the other approaches throughout the post.
Free memory
%%rank [0]
del model.model, model
del optimizer
torch.cuda.empty_cache()
Multi GPU Training
Parallel Code Execution on Multiple GPUs
In the realm of multi GPU training, our programs operate in a single program, multiple data (SPMD) computing model. The same code we write, in this case the "single program", will run on every GPU but we will have to use conditional statements (if/else) to run code on specific GPUs.
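Each GPU is assigned a rank, which we can think of simply as a numerical ID for a particular GPU (strictly speaking, a rank identifies a process, and here we run one process per GPU). We can get this rank by calling dist.get_rank(), and the total number of processes, which here equals the number of GPUs, by calling dist.get_world_size().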
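Outside of nbdistributed, a common way to launch one process per GPU is PyTorch's `torchrun` launcher, which exports the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables to each process. The sketch below is an assumption about how a standalone script might read them; `read_dist_env` is a hypothetical helper, and it falls back to a single-process layout when the variables are absent.

```python
import os


def read_dist_env() -> dict:
    """Read the process layout that torchrun exports (hypothetical helper)."""
    return {
        # Global rank of this process across all nodes
        "rank": int(os.environ.get("RANK", 0)),
        # Rank within this node, typically used to pick the GPU (cuda:{local_rank})
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),
        # Total number of processes (one per GPU in this post)
        "world_size": int(os.environ.get("WORLD_SIZE", 1)),
    }


# Simulate the environment of the third of four processes on a single node
os.environ.update({"RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "4"})
env = read_dist_env()
print(env)  # {'rank': 2, 'local_rank': 2, 'world_size': 4}
```

On a single node like ours, the global rank and the local rank coincide; they only differ once training spans several nodes.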
ws = dist.get_world_size() # The total number of GPUs
rank = dist.get_rank()     # The ID of the current GPU this code is running on
ws, rank
🔹 Rank 0:
  (4, 0)
🔹 Rank 1:
  (4, 1)
🔹 Rank 2:
  (4, 2)
🔹 Rank 3:
  (4, 3)
We can run code on a specific GPU by wrapping it in a conditional statement that checks the value of rank. For example:
if rank == 0:
    print(f"This should execute on the first GPU rank {rank}")
else:
    print(f"This is running on GPU rank {rank}")
🔹 Rank 0:
  This should execute on the first GPU rank 0
🔹 Rank 1:
  This is running on GPU rank 1
🔹 Rank 2:
  This is running on GPU rank 2
🔹 Rank 3:
  This is running on GPU rank 3
Sharding the Dataset
First, let's make sure we feed each instance of our model a different shard of our dataset. While we want to keep the parameters synchronized, we also want each instance to see different chunks of our data so that we benefit from parallelism. We won't get any speedup if the instances of our model train on the same batches of data: we would iterate through the entire dataset on every GPU, effectively running single GPU training 4 times!
Since the value of rank will be unique on each GPU, we can use it to partition our dataset and ensure each shard is different.
ws = dist.get_world_size()
rank = dist.get_rank()
train_ds = dataset.shuffle(seed=common_seed)
# To ensure equal distribution of the dataset, we calculate the dataset length per
# rank/GPU by dividing total dataset length by world size
ds_length = len(train_ds)
ds_length_per_gpu = ds_length // ws
# We calculate the start and end index of the data shard
# Remember the value of `rank` will be unique on each GPU
start = rank * ds_length_per_gpu
end = start + ds_length_per_gpu if rank != ws - 1 else ds_length
train_shard = train_ds.select(list(range(start, end)))
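The same index arithmetic can be checked without GPUs. The pure-Python sketch below (both function names are hypothetical) mirrors the `start`/`end` computation above and counts how many full batches each rank would get with `drop_last=True`. It surfaces one caveat: when the dataset length is not divisible by the world size, the last rank gets a larger shard and can end up with more batches than the others. In a full training loop that synchronizes gradients every step, that extra batch could leave the last rank waiting in a collective the other ranks never join.

```python
def shard_bounds(ds_length: int, world_size: int, rank: int) -> tuple[int, int]:
    """Contiguous shard for `rank`; the last rank also takes the remainder."""
    per_rank = ds_length // world_size
    start = rank * per_rank
    end = start + per_rank if rank != world_size - 1 else ds_length
    return start, end


def batches_per_rank(ds_length: int, world_size: int, batch_size: int) -> list[int]:
    """Number of full batches each rank sees when its DataLoader uses drop_last=True."""
    counts = []
    for rank in range(world_size):
        start, end = shard_bounds(ds_length, world_size, rank)
        counts.append((end - start) // batch_size)
    return counts


print([shard_bounds(10, 4, r) for r in range(4)])  # [(0, 2), (2, 4), (4, 6), (6, 10)]
print(batches_per_rank(10, 4, batch_size=2))       # [1, 1, 1, 2] (uneven!)
print(batches_per_rank(64, 4, batch_size=16))      # [1, 1, 1, 1]
```

PyTorch's built-in `DistributedSampler` avoids this by padding (or, with `drop_last=True`, trimming) the index list so that every rank receives the same number of samples.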
Let's confirm that the first batch of data is different on each GPU.
train_dataloader = get_dataloader(train_shard)
batch = next(iter(train_dataloader))
batch['input_ids'][0][:5]
🔹 Rank 0:
  tensor([14550,    82,  6538,   637,    99])
🔹 Rank 1:
  tensor([3750, 4288, 5397,  330, 9478])
🔹 Rank 2:
  tensor([ 504, 9238,  436,  932,  411])
🔹 Rank 3:
  tensor([14086,   659, 10298,   284,   260])
Initializing the Models
When we initialize a model on multiple GPUs, any randomly initialized parameters get a different random initialization on each GPU. In our case the pretrained backbone weights are loaded from the checkpoint and are identical everywhere, but the newly added classification head is initialized randomly. Without synchronization each GPU would start training from a different set of weights, which means we would essentially be training different models. We want every model to start from the same point in parameter space, and we want to keep the parameters synchronized throughout training.
On initialization, we can synchronize the models across our GPUs by starting every GPU from the same random seed. Then we use the dist.broadcast operation from torch.distributed to send the parameters from one GPU to all of the others, so that each GPU can verify its parameters match.
Remember, we are training one model instance per GPU!
But first, let's demonstrate that the parameters of our models are not the same if we initialize them naively without synchronizing.
model = get_smol_model()
model.to(device)
model.train()
# We check every parameter of the model
for p in model.parameters():
    # First we allocate some space for receiving the broadcasted parameter.
    rank0_param = p.data.clone()
    # Broadcast the `rank0_param` values from the rank 0 GPU into `rank0_param` on other
    # GPUs, they are the same shape.
    dist.broadcast(rank0_param, src=0)
    # Raise an error if they are not equal
    if not torch.equal(p.data, rank0_param):
        raise ValueError("Model parameters are not in sync")
❌ Error on Rank 1: Model parameters are not in sync
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    model = get_smol_model()
          ^
SyntaxError: invalid syntax
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 10, in <module>
ValueError: Model parameters are not in sync
❌ Error on Rank 2: Model parameters are not in sync
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    model = get_smol_model()
          ^
SyntaxError: invalid syntax
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 10, in <module>
ValueError: Model parameters are not in sync
❌ Error on Rank 3: Model parameters are not in sync
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    model = get_smol_model()
          ^
SyntaxError: invalid syntax
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 10, in <module>
ValueError: Model parameters are not in sync
This fails because we first need to set the same seed on every GPU so that every model's parameters are initialized to the same values. Looking at the error message carefully, we can see errors on GPUs 1-3 (ranks 1-3) and no error on GPU 0 (rank 0). Because we broadcast the parameters from GPU 0, rank 0 compares each parameter against its own copy, so the check always passes there.
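The failure pattern above (ranks 1-3 fail, rank 0 passes) can be reproduced in a tiny single-process simulation. In this sketch, plain Python lists stand in for parameters, `random.Random` instances stand in for each process's random number generator, and reading rank 0's list stands in for `dist.broadcast(..., src=0)`. All helper names are hypothetical, and modeling unseeded processes as `seed = 1000 + rank` is an assumption made only to keep the example deterministic.

```python
import random

WORLD_SIZE = 4


def init_head(rng: random.Random, n: int = 4) -> list[float]:
    """Stand-in for a randomly initialized classification head."""
    return [rng.gauss(0.0, 0.02) for _ in range(n)]


def ranks_out_of_sync(params_per_rank: list[list[float]]) -> list[int]:
    """Simulate broadcast from rank 0 followed by the equality check on every rank."""
    rank0_param = params_per_rank[0]  # what every rank receives from the broadcast
    return [r for r, p in enumerate(params_per_rank) if p != rank0_param]


# Without a shared seed, each process has its own RNG state
unseeded = [init_head(random.Random(1000 + r)) for r in range(WORLD_SIZE)]
# With set_seed(), every process starts from the same RNG state (314 = common_seed)
seeded = [init_head(random.Random(314)) for _ in range(WORLD_SIZE)]

print(ranks_out_of_sync(unseeded))  # [1, 2, 3]  (rank 0 compares against itself)
print(ranks_out_of_sync(seeded))    # []
```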
Let's ensure the parameters are initialized to the same values by setting the same seed on all GPUs and run the same code as before and observe the difference.
set_seed() # Set the random seed to ensure every model starts with the same parameters
model = get_smol_model()
model.to(device)
model.train()
for p in model.parameters():
    # Allocate space for the parameter and broadcast from rank 0 to all other ranks.
    rank0_param = p.data.clone()
    dist.broadcast(rank0_param, src=0)
    if not torch.equal(p.data, rank0_param):
        raise ValueError("Model parameters are not in sync")
Free memory
del model.model, model
torch.cuda.empty_cache()
Data Parallelism
Training
Now that we've sharded our dataset and initialized our model with synchronized weights, we can begin training it using Data Parallelism! Let's implement this step by step.
set_seed() # Set the seed so the models initialize with the same parameters
train_dataloader = get_dataloader(train_shard)
batch = next(iter(train_dataloader)) # Get one batch from our dataset for testing
model = get_smol_model()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
The first step in training is to take one batch of data and have the model make a prediction on it.
batch = {k: v.to(device) for k, v in batch.items()}
output = model(**batch) # Make a prediction on one batch of data
If the batches are different, the resulting loss will be different on each GPU. If we then backpropagate and update the parameters, the parameters on each GPU will diverge, and we would be training multiple different models instead of multiple instances of the same model. Let's do this and see what happens.
output.loss.backward() # Backpropagate to calculate the gradients
optimizer.step() # Step the optimizer and update the parameters
If we check now, we will see that the gradients (and therefore the updated parameters) are not the same. To collect the gradient tensors from every GPU on every GPU, we can use the all_gather operation. Each GPU has a different set of gradients, since each trained on a different batch of data, and we want to gather all of them on every GPU to check whether they match.
Since the gradients have the same shape and data type on every GPU, we can allocate space on each GPU, as we did previously, for the data gathered from the other GPUs; all_gather then writes the received data into that allocated memory.
More information on all_gather can be found in the torch.distributed documentation.
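As a mental model for the gradient check, here is a single-process simulation of `all_gather`: each "rank" contributes one local value, and afterwards every rank holds the same list of all values, ordered by rank. Floats stand in for gradient tensors, and `simulate_all_gather` is a hypothetical helper, not part of `torch.distributed`.

```python
def simulate_all_gather(local_values: list[float]) -> list[list[float]]:
    """Model all_gather: afterwards every rank holds every rank's value, in rank order."""
    world_size = len(local_values)
    return [list(local_values) for _ in range(world_size)]


# Each rank computed a different gradient (a single float here) on its own micro batch
local_grads = [0.5, -0.25, 1.0, 0.75]
gathered = simulate_all_gather(local_grads)

for rank, received in enumerate(gathered):
    # Same idea as the real check: compare the local gradient with every gathered one
    in_sync = all(other == local_grads[rank] for other in received)
    print(rank, received, "in sync" if in_sync else "NOT in sync")
```

Every simulated rank reports "NOT in sync", which is the same outcome the real check produces on the GPUs.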
# Check the gradients
for p in model.parameters():
    if p.grad is not None:
        # Clone the local gradient
        local_grad = p.grad.clone()
        # Allocate one empty tensor with the same shape as the gradient for each GPU.
        gathered = [torch.empty_like(local_grad) for _ in range(dist.get_world_size())]
        # Each GPU will send the `local_grad` tensor and receive them from all other
        # GPUs in the `gathered` list
        dist.all_gather(gathered, local_grad)
        # Check every value in the `gathered` list
        for rank_idx, other_grad in enumerate(gathered):
            # Verify that the gathered gradient tensor is the same as our local gradient tensor
            if not torch.equal(local_grad, other_grad):
                raise ValueError("Model gradients are not in sync")
❌ Error on Rank 2: Model gradients are not in sync
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 2
    for p in model.parameters():
    ^^^
SyntaxError: invalid syntax
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 18, in <module>
ValueError: Model gradients are not in sync
❌ Error on Rank 1: Model gradients are not in sync
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 2
    for p in model.parameters():
    ^^^
SyntaxError: invalid syntax
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 18, in <module>
ValueError: Model gradients are not in sync
❌ Error on Rank 3: Model gradients are not in sync
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 2
    for p in model.parameters():
    ^^^
SyntaxError: invalid syntax
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 18, in <module>
ValueError: Model gradients are not in sync
❌ Error on Rank 0: Model gradients are not in sync
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 2
    for p in model.parameters():
    ^^^
SyntaxError: invalid syntax
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 18, in <module>
ValueError: Model gradients are not in sync
The errors in the output demonstrate that our gradients are not in sync!
So after we do backpropagation and calculate the gradients, but before we step the optimizer to update the parameters (and also before we zero the gradients!) we must synchronize the gradients across our models. If we step the optimizer to update the model's parameters after the gradients are in sync, the parameters will remain in sync as well.
Synchronizing the Gradients
Let's review. To correctly train a model across multiple GPUs while benefiting from Data Parallelism, we must:
- Start all instances of the model with identical parameters.
- Compute gradients on a different micro batch of data for each model instance.
- Synchronize (average) the gradients after backpropagation so that the optimizer updates keep the parameters identical across all models.
We want the gradients to all be the same but how do we correctly combine the gradient values? We simply take their average. But why the average?
Let's go back to the fundamentals and single GPU training momentarily, and take our intuition from batching. When we train on a single GPU with a batch of, for example, 32 items, we're computing the gradient for each example and then averaging those 32 gradients together before updating the parameters. By averaging our computed gradients, we are reducing their variance and aligning them more closely with the full gradient of our entire dataset. This averaged gradient gives us a more stable and representative direction to update our model parameters.
The principle is the same in distributed training! Instead of averaging gradients from 32 items in a mini-batch on one GPU, we're averaging gradients from different mini-batches across multiple GPUs.
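We can check this intuition numerically. In the sketch below (pure Python, with a one-parameter linear model and mean squared error chosen purely for illustration), averaging the mean gradients from four equal-sized micro batches gives the same value as the mean gradient over the whole global batch. The equality relies on every micro batch having the same size; with unequal sizes, a plain average would over-weight the examples in the smaller micro batches.

```python
def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w*x - y)^2) over a batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.3
xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 2.0, 2.9, 4.2, 5.1, 5.8, 7.2, 8.1]

# Gradient of the mean loss over the full (global) batch of 8 examples
full_grad = grad_mse(w, xs, ys)

# Split into 4 equal micro batches of 2 examples, one per simulated GPU
world_size, micro = 4, 2
shard_grads = [
    grad_mse(w, xs[r * micro:(r + 1) * micro], ys[r * micro:(r + 1) * micro])
    for r in range(world_size)
]
# What all_reduce(SUM) followed by division by the world size computes
averaged = sum(shard_grads) / world_size

print(full_grad, averaged)  # equal up to floating-point rounding
```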
If the values we want to average were all on one machine, in code we could use the reduce function and sum the values in the list to get the total, then divide by the total number of elements to obtain the average. But our values (the gradients) are distributed across multiple GPUs.
Here we can use another distributed operation, all_reduce with ReduceOp.SUM, to sum our gradients across GPUs, and then divide by the number of GPUs to get the average. The all_reduce operation puts the result on every GPU, while reduce puts it only on a single destination GPU; since a copy of our model lives on every GPU, we need all_reduce.
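The difference between the two collectives is easy to model in plain Python. In this single-process sketch, each list element plays the role of one GPU's gradient, and the helper names are hypothetical stand-ins for `dist.reduce` and `dist.all_reduce` with `ReduceOp.SUM`.

```python
def simulate_reduce(values: list[float], dst: int = 0) -> list:
    """Model reduce(SUM): only rank `dst` receives the sum (other ranks get None here)."""
    total = sum(values)
    return [total if rank == dst else None for rank in range(len(values))]


def simulate_all_reduce(values: list[float]) -> list[float]:
    """Model all_reduce(SUM): every rank receives the sum."""
    total = sum(values)
    return [total for _ in values]


local_grads = [0.5, -0.25, 1.0, 0.75]  # one value per simulated GPU
world_size = len(local_grads)

print(simulate_reduce(local_grads))     # [2.0, None, None, None]
summed = simulate_all_reduce(local_grads)
averaged = [s / world_size for s in summed]
print(averaged)                         # [0.5, 0.5, 0.5, 0.5]
```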
# Same setup as before
set_seed()
train_dataloader = get_dataloader(train_shard)
batch = next(iter(train_dataloader))
model = get_smol_model()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = {k: v.to(device) for k, v in batch.items()}
output = model(**batch) # Make a prediction
output.loss.backward() # Backpropagate to calculate the gradients
# Sync the gradients after backpropagation
for p in model.parameters():
    if p.grad is not None:
        # Sum p.grad across all GPUs; every GPU receives the result
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        # Take the average by dividing the sum by the number of GPUs
        p.grad /= dist.get_world_size()
        # Check if gradients are the same as we did previously
        local_grad = p.grad.clone()
        gathered = [torch.empty_like(local_grad) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, local_grad)
        for rank_idx, other_grad in enumerate(gathered):
            if not torch.equal(local_grad, other_grad):
                raise ValueError("Model gradients are not in sync")
Implementing Simple DDP
We have already implemented a simple version of Distributed Data Parallelism! Let's put the code we have so far into a class to see it.
class SimpleDDP:
    def __init__(self, model: torch.nn.Module):
        self.model = model
        # Verify our model parameters are in sync at initialization as we did before
        for p in model.parameters():
            rank0_param = p.data.clone()
            dist.broadcast(rank0_param, src=0)
            if not torch.equal(p.data, rank0_param):
                raise ValueError(
                    "Model parameters are not synced across ranks during __init__. "
                    "Use set_seed() before creating the model."
                )
    def sync_gradients(self):
        """
        Sync the gradients across all GPUs using all_reduce.
        Should be called after the backward pass and before optimizer.step().
        """
        for p in self.model.parameters():
            if p.grad is not None:
                # Sum the gradient across all GPUs and take the average
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                # Divide by the total number of GPUs to obtain the average
                p.grad /= dist.get_world_size()
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
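To see why synchronized gradients keep the parameters synchronized, here is a single-process, pure-Python simulation of what a SimpleDDP training step does on four ranks. It is only a model under simplifying assumptions: one scalar parameter, a linear model with mean squared error, and plain SGD in place of Adam (Adam also updates deterministically from identical gradients and identical optimizer state, so the same argument applies). `train` and `SHARDS` are hypothetical names.

```python
WORLD_SIZE = 4
# One micro batch (xs, ys) per simulated rank
SHARDS = [
    ([1.0, 2.0], [2.1, 3.9]),
    ([3.0, 4.0], [6.2, 7.8]),
    ([5.0, 6.0], [9.9, 12.1]),
    ([7.0, 8.0], [14.2, 15.8]),
]


def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w*x - y)^2) for a one-parameter linear model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def train(sync: bool, steps: int = 3, lr: float = 0.01) -> list[float]:
    params = [0.0] * WORLD_SIZE  # identical starting point on every rank
    for _ in range(steps):
        # Local backward pass: each rank gets a different gradient from its own shard
        grads = [grad_mse(params[r], *SHARDS[r]) for r in range(WORLD_SIZE)]
        if sync:
            # sync_gradients(): all_reduce(SUM), then divide by the world size
            avg = sum(grads) / WORLD_SIZE
            grads = [avg] * WORLD_SIZE
        # optimizer.step(), using plain SGD here instead of Adam
        params = [p - lr * g for p, g in zip(params, grads)]
    return params


print(len(set(train(sync=True))))   # 1: every rank holds the same parameter
print(len(set(train(sync=False))))  # 4: four different models
```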
Now let's set up a training loop so we can profile our Data Parallelism implementation and compare its execution to the single GPU version.
set_seed()
train_dataloader = get_dataloader(train_shard)
model = get_smol_model()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# We wrap our model in the SimpleDDP class
ddp_model = SimpleDDP(model)
max_steps = 6
with get_profiler("/workspace/ddp/multi") as prof:
    for (i, batch) in enumerate(train_dataloader):
        if i >= max_steps:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        output = ddp_model(**batch)
        # Calculate the gradients
        output.loss.backward()
        # Sync the gradients after the backwards step!!
        ddp_model.sync_gradients()
        # Update and zero the gradients after synchronizing
        optimizer.step()
        optimizer.zero_grad()
        # step the profiler
        prof.step()
Profiling
Let's examine the trace for GPU rank 0 from our DDP training run. As in the single GPU trace, we see the forward and backward passes launching kernels from the CPU.
New in this trace are the pink blocks at the end of every "Profiler Step" in the GPU stream. Those blocks are the all_reduce operations, and no other computation is running on the GPU while they execute. We can verify this in the trace viewer by hovering over the pink blocks and reading the tooltip, which identifies them as all_reduce calls.
While our GPUs wait for the all_reduce to complete, they do no computation. So even though we are training on more GPUs, the efficiency of each GPU is lower than in the single GPU run. The "Temporal Breakdown" chart shows that ~15% of our GPU time is not spent on computation, and the "Computation-Communication Overlap" chart shows 0% overlap.
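A back-of-the-envelope estimate shows what that idle time costs. Assuming, as a simplification rather than a measurement, that each GPU's compute work per step is unchanged and the ~15% non-compute time is pure communication overhead, four GPUs deliver roughly 4 × 0.85 = 3.4 times the throughput of one GPU computing without interruption, instead of an ideal 4 times. The function names are hypothetical.

```python
def effective_speedup(num_gpus: int, compute_fraction: float) -> float:
    """Throughput relative to one GPU that computes 100% of the time (simplified model)."""
    return num_gpus * compute_fraction


def scaling_efficiency(num_gpus: int, compute_fraction: float) -> float:
    """Fraction of ideal linear scaling that is achieved."""
    return effective_speedup(num_gpus, compute_fraction) / num_gpus


# ~15% of GPU time was spent outside computation in the SimpleDDP run
speedup = effective_speedup(4, 0.85)
print(f"{speedup:.2f}x instead of an ideal 4x ({scaling_efficiency(4, 0.85):.0%} efficiency)")
# 3.40x instead of an ideal 4x (85% efficiency)
```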
Gradient Accumulation
Communication Overhead
Time on GPUs is expensive, and as we saw above, our GPUs do no computation while they wait for the gradients to be synchronized. However, if we want to scale up our training we have no choice but to pay some communication cost, since we need to share information across GPUs.
Can we improve on this? Is it necessary to sync our gradients after every step of the training loop?
We can borrow another idea from single GPU training: gradient accumulation. On a single GPU we can use gradient accumulation to process several batches of data while continuously summing the gradients, without updating the parameters (stepping the optimizer). When the accumulation phase is done, the gradients are averaged and the parameters are updated once.
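To make the "sum, then average" step concrete, here is a minimal sketch in plain Python (no torch) using a hypothetical one-weight model, prediction = w * x, with a squared-error loss. Assuming equal-size micro batches, summing each micro batch's gradient of loss / k (what calling (loss / k).backward() once per micro batch would accumulate) gives the gradient of the mean loss over the whole batch. The helper names are illustrative only.

```python
# Hypothetical one-weight model: prediction = w * x, per-example loss = (w * x - t) ** 2.
# Plain Python floats stand in for tensors; no torch is required.


def grad_mean_loss(w, batch):
    """Gradient of the mean squared error over `batch` with respect to w."""
    return sum(2 * (w * x - t) * x for x, t in batch) / len(batch)


def accumulated_grad(w, micro_batches):
    """Gradient accumulation over k equal-size micro batches.

    Each step adds the gradient of (micro-batch loss / k), which is what
    calling (loss / k).backward() once per micro batch would accumulate.
    """
    k = len(micro_batches)
    grad = 0.0  # plays the role of p.grad, starting from zero
    for micro_batch in micro_batches:
        grad += grad_mean_loss(w, micro_batch) / k
    return grad  # the optimizer would step once, here


w = 0.5
data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0), (5.0, 7.0), (6.0, 6.0)]
micro_batches = [data[0:2], data[2:4], data[4:6]]  # three micro batches of two

full_batch = grad_mean_loss(w, data)
accumulated = accumulated_grad(w, micro_batches)
# The two values agree up to floating-point rounding
print(full_batch, accumulated)
```

In the data parallel setting each rank accumulates locally like this, and only the final accumulated gradient needs an all_reduce, so a window of k micro batches costs one synchronization instead of k. Note that the training loops in this post call output.loss.backward() without dividing by grad_accum_steps, so they sum rather than average over micro batches; Adam's updates are largely insensitive to a constant gradient scale, but dividing the loss by grad_accum_steps matches the averaging described above.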
This technique allows us to simulate training with larger batches than would fit in the memory of a single GPU. On multiple GPUs we can apply gradient accumulation so that we do not pay the communication cost of synchronizing the gradients every time a batch of data is processed in the training loop. We add gradient accumulation to our SimpleDDP class in the code below.
class SimpleDDPWithGA:
    ...
    def sync_gradients(self):
        """
        Sync the gradients across all GPUs using all_reduce.
        Should be called after the backwards pass.
        """
        if not self.should_sync:
            return  # skip syncing
        for p in self.model.parameters():
            if p.grad is not None:
                # Sum the gradient across all GPUs
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                # Divide by the total number of GPUs to obtain the average
                p.grad /= dist.get_world_size()
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    @property
    def should_sync(self):
        return self._sync
    def enable_grad_sync(self):
        self._sync = True
    def disable_grad_sync(self):
        self._sync = False
Here is the full code:
class SimpleDDPWithGA:
    def __init__(self, model: torch.nn.Module):
        self.model = model
        # Sync gradients by default until told otherwise
        self._sync = True
        # Verify our model parameters are in sync at initialization as we did before
        for p in model.parameters():
            rank0_param = p.data.clone()
            dist.broadcast(rank0_param, src=0)
            if not torch.equal(p.data, rank0_param):
                raise ValueError(
                    "Model parameters are not synced across ranks during __init__. "
                    "Use set_seed"
                )
    def sync_gradients(self):
        """
        Sync the gradients across all GPUs using all_reduce.
        Should be called after the backwards pass.
        """
        if not self.should_sync:
            return  # skip syncing
        for p in self.model.parameters():
            if p.grad is not None:
                # Sum the gradient across all GPUs
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                # Divide by the total number of GPUs to obtain the average
                p.grad /= dist.get_world_size()
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    @property
    def should_sync(self):
        return self._sync
    def enable_grad_sync(self):
        self._sync = True
    def disable_grad_sync(self):
        self._sync = False
Now we repeat the process of setting up our model for training.
set_seed()
train_dataloader = get_dataloader(train_shard)
model = get_smol_model()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ddp_model = SimpleDDPWithGA(model)
max_steps = 6
One last thing: we define grad_accum_steps to specify how often we want to sync our gradients!
grad_accum_steps = 3
We set up the profiler and run the training loop the same way as earlier, with some slight changes.
We enable gradient syncing only at sync steps, and only on those steps do we sync the gradients, step the optimizer, and zero the gradients.
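As a quick sanity check of the schedule, the sketch below reuses the loop's condition, (i + 1) % grad_accum_steps == 0, to list the 0-based loop indices that sync. sync_steps is a hypothetical helper for illustration, not part of our class.

```python
def sync_steps(max_steps, grad_accum_steps):
    """Loop indices i (0-based) at which the training loop syncs gradients,
    using the same condition as the loop: (i + 1) % grad_accum_steps == 0."""
    return [i for i in range(max_steps) if (i + 1) % grad_accum_steps == 0]


steps = sync_steps(max_steps=6, grad_accum_steps=3)
print(steps)                   # [2, 5]
print([i + 1 for i in steps])  # [3, 6], i.e. steps 3 and 6
print(len(sync_steps(6, 1)))   # 6: without accumulation every step syncs
```

Six micro batches now need two gradient syncs instead of six. If max_steps were not a multiple of grad_accum_steps (sync_steps(7, 3) is still [2, 5]), the gradients of the trailing micro batches would be accumulated but never synced or applied.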
with get_profiler("/workspace/ddp/ga") as prof:
    for (i, batch) in enumerate(train_dataloader):
        if i >= max_steps:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        # Enable or disable syncing depending on the step number in the training loop
        if (i + 1) % grad_accum_steps == 0:
            ddp_model.enable_grad_sync()
        else:
            ddp_model.disable_grad_sync()
        output = ddp_model(**batch)
        # Calculate the gradients
        output.loss.backward()
        if ddp_model.should_sync:
            # If it is time to sync, sync the gradients after the backwards step
            ddp_model.sync_gradients()
            # Update and zero the gradients after synchronizing
            optimizer.step()
            optimizer.zero_grad()
        # step the profiler
        prof.step()
Profiling
In the trace below we still see pink blocks where we synchronize our gradients, but now only twice: during step 3 and during step 6, the final step of our training loop. There are far fewer gradient syncs than with our SimpleDDP class, and GPU utilization is more continuous throughout training (though the GPUs are still idle during synchronization).
The "Temporal Breakdown" chart shows that adding gradient accumulation significantly reduced our non-compute time, from ~15% down to ~5%, a reduction of ~66%. Over a longer training run scaled up to thousands of steps, this would add up significantly. However, the "Computation-Communication Overlap" chart still shows 0% overlap.
Overlapping Communication and Computation
Idle GPUs
While we have reduced how often we sync, our GPUs are not doing any computations while they wait for the gradient sync communications to finish. We can verify this by zooming into the pink communication blocks in our trace and observing that no compute operations occur during this period.
From here we can further optimize by overlapping communication and computation. In our previous example we synchronized the gradients only after completing the entire backward pass. But gradients are computed backwards, from the last layer to the first, so we can start synchronizing a layer's gradients as soon as they are ready, while the gradients of the preceding layers are still being computed.
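To see why this ordering matters, here is a toy timing model in plain Python with made-up per-layer times (hypothetical numbers, not measurements). It assumes communication runs on its own channel without slowing compute down, and that each layer's all_reduce can start once that layer's gradient is ready and the previous all_reduce has finished.

```python
def blocking_time(compute, comm):
    """Blocking sync, whether after the backward pass or layer by layer:
    the GPU waits for every all_reduce, so all communication time is exposed."""
    return sum(compute) + sum(comm)


def overlapped_time(compute, comm):
    """Each layer's all_reduce runs on a separate channel, starting as soon as
    (a) that layer's gradient is ready and (b) the previous all_reduce is done.
    Compute never waits for communication."""
    compute_done = 0.0
    comm_done = 0.0
    for c, m in zip(compute, comm):  # layers in backward order (last layer first)
        compute_done += c  # this layer's gradient is now ready
        comm_start = max(compute_done, comm_done)
        comm_done = comm_start + m
    return max(compute_done, comm_done)


compute = [4.0, 4.0, 4.0, 4.0]  # hypothetical per-layer backward times (ms)
comm = [2.0, 2.0, 2.0, 2.0]     # hypothetical per-layer all_reduce times (ms)
print(blocking_time(compute, comm))    # 24.0
print(overlapped_time(compute, comm))  # 18.0
```

With blocking communication it makes no difference whether we sync after the backward pass or layer by layer: every millisecond of communication is added to the step. With overlap, only the communication that outlasts the compute is exposed, here the 2 ms all_reduce for the first layer of the model.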
In PyTorch we can register a backward hook on a tensor to run some code once the gradient with respect to that tensor has been computed during the backward pass. If we synchronize the gradients in the backward hook, our GPUs can continue the backpropagation computations during that wait.
Broken down step by step:
- The gradients of the tensors in the last layer are computed during the backward pass.
- The backward hook is called for the tensors in the last layer, triggering a sync for those gradients.
- The gradients of the prior layer's tensors are computed during the backward pass.
- The backward hook is called for that layer's tensors, triggering a sync for its gradients.
- We repeat until we have synced the gradients for the first layer.
Here is what it looks like in code.
class SimpleDDPHookGA:
    def __init__(self, model: torch.nn.Module):
        ...
        self.register_backward_hook(self._sync_gradient)
    def _sync_gradient(self, param):
        """
        Sync the accumulated gradient of `param` across all GPUs using all_reduce.
        """
        if not self.should_sync:
            return  # skip syncing
        # Sum the gradient across all GPUs
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        # Divide by the total number of GPUs to obtain the average
        param.grad /= dist.get_world_size()
    def register_backward_hook(self, hook_fn):
        self.sync_hooks = []
        for p in self.model.parameters():
            if p.requires_grad:
                # Post-accumulate hooks run after this backward pass's gradient
                # has been added into p.grad (PyTorch 2.1+), so with gradient
                # accumulation we sync the whole accumulated gradient
                h = p.register_post_accumulate_grad_hook(hook_fn)
                self.sync_hooks.append(h)
    ...
Here is the full code:
class SimpleDDPHookGA:
    def __init__(self, model: torch.nn.Module):
        self.model = model
        # Sync gradients by default until told otherwise
        self._sync = True
        for p in model.parameters():
            rank0_param = p.data.clone()
            dist.broadcast(rank0_param, src=0)
            if not torch.equal(p.data, rank0_param):
                raise ValueError(
                    "Model parameters are not synced across ranks during __init__. "
                    "Use set_seed"
                )
        self.register_backward_hook(self._sync_gradient)
    def _sync_gradient(self, param):
        """
        Sync the accumulated gradient of `param` across all GPUs using all_reduce.
        Called by autograd after the gradient is accumulated into param.grad.
        """
        if not self.should_sync:
            return  # skip syncing
        # Sum the gradient across all GPUs
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        # Divide by the total number of GPUs to obtain the average
        param.grad /= dist.get_world_size()
    def register_backward_hook(self, hook_fn):
        self.sync_hooks = []
        for p in self.model.parameters():
            if p.requires_grad:
                # Runs after the new gradient is accumulated into p.grad (PyTorch 2.1+),
                # so we sync the full accumulated gradient, not just this micro batch
                h = p.register_post_accumulate_grad_hook(hook_fn)
                self.sync_hooks.append(h)
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    @property
    def should_sync(self):
        return self._sync
    def enable_grad_sync(self):
        self._sync = True
    def disable_grad_sync(self):
        self._sync = False
Let's repeat the setup to profile and train our model from before so we can compare the results.
set_seed()
train_dataloader = get_dataloader(train_shard)
model = get_smol_model()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ddp_model = SimpleDDPHookGA(model)
max_steps = 6
grad_accum_steps = 3
We set up the profiler and run the training loop the same way as earlier, with one change: we no longer initiate the gradient sync manually, because the backward hook handles it for us.
with get_profiler("/workspace/ddp/hook-ga") as prof:
    for (i, batch) in enumerate(train_dataloader):
        if i >= max_steps:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        if (i + 1) % grad_accum_steps == 0:
            ddp_model.enable_grad_sync()
        else:
            ddp_model.disable_grad_sync()
        output = ddp_model(**batch)
        output.loss.backward()
        if ddp_model.should_sync:
            # We no longer sync the gradients explicitly, it is handled in the backwards hook
            optimizer.step()
            optimizer.zero_grad()
        prof.step()
Profiling
We can see the difference in the trace below: the pink all_reduce blocks are now interleaved with the GPU computations!
Asynchronous Communication
If we look carefully, however, we notice that our communication still blocks computation: the two are interleaved rather than overlapped, which was not the goal. The "Computation-Communication Overlap" chart confirms this, still showing 0% overlap.
We have one more trick we can apply: running our communication operation (all_reduce) asynchronously, so that we can begin computing the next set of gradients while the current set is being synchronized.
Now let's try again with non-blocking communication by setting async_op=True, and running all_reduce asynchronously.
We can use the future API to chain a callback, after_sync_cb, that completes the averaging once the sync is done. We also create a helper method, wait_for_gradients, which calls .wait on each pending future. Contrary to what you might expect, .wait does not block until the GPU operations are complete. For CUDA tensors it makes the current CUDA stream wait on the pending work, so that later operations on that stream (such as optimizer.step) are scheduled after it, and then returns even if the kernels are still running. To block the CPU until all queued GPU work has finished, we can additionally call torch.cuda.synchronize.
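The launch, callback, and wait pattern can be tried out on the CPU with Python's concurrent.futures. This is only an analogy: a thread pool stands in for NCCL, fake_all_reduce_sum, then, and ToyAsyncSync are hypothetical names, and unlike CUDA futures, concurrent.futures' wait really does block the calling thread until the work is done. The then helper loosely mirrors torch.futures.Future.then by returning a new future that completes only after the callback has run; in concurrent.futures, waiting on the original future does not by itself guarantee that its done-callbacks have finished, so we wait on the chained future instead.

```python
from concurrent.futures import Future, ThreadPoolExecutor, wait

WORLD_SIZE = 4  # assumption: four simulated ranks, matching our four GPUs


def fake_all_reduce_sum(grads_per_rank):
    """Stand-in for all_reduce(SUM): element-wise sum of each rank's gradient."""
    return [sum(vals) for vals in zip(*grads_per_rank)]


def then(fut, fn):
    """Return a new Future completed with fn(fut) after fut finishes,
    loosely mirroring torch.futures.Future.then."""
    chained = Future()

    def _callback(done):
        try:
            chained.set_result(fn(done))
        except Exception as exc:  # propagate errors to whoever waits
            chained.set_exception(exc)

    fut.add_done_callback(_callback)
    return chained


class ToyAsyncSync:
    def __init__(self, executor):
        self.executor = executor
        self.pending_futures = []
        self.averaged = {}

    def launch(self, name, grads_per_rank):
        # Non-blocking: submit returns a Future right away
        work = self.executor.submit(fake_all_reduce_sum, grads_per_rank)

        def after_sync_cb(done):
            # Finish the average once the "reduction" is complete
            self.averaged[name] = [g / WORLD_SIZE for g in done.result()]

        # Keep the chained future so waiting also covers the callback
        self.pending_futures.append(then(work, after_sync_cb))

    def wait_for_gradients(self):
        wait(self.pending_futures)
        for fut in self.pending_futures:
            fut.result()  # re-raise any exception from a callback
        self.pending_futures.clear()


with ThreadPoolExecutor(max_workers=2) as pool:
    syncer = ToyAsyncSync(pool)
    # Gradients for two hypothetical parameters, one list per rank
    syncer.launch("layer2.weight", [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
    syncer.launch("layer1.weight", [[4.0], [4.0], [8.0], [0.0]])
    # ... the main thread is free to keep working here ...
    syncer.wait_for_gradients()

print(syncer.averaged)
```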
class SimpleDDPAsyncHookGA:
    def __init__(self, model: torch.nn.Module):
        self._sync = True
        self.pending_futures = []  # Store futures to wait on
        ...
        self.register_backward_hook(self._sync_gradient)
    def _sync_gradient(self, param):
        # Called by autograd after the gradient is accumulated into param.grad
        if not self.should_sync:
            return  # skip syncing
        # Sum the accumulated gradient across all GPUs without blocking
        work = dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, async_op=True)
        def after_sync_cb(fut):
            # Average in-place once the reduction has finished
            param.grad.div_(dist.get_world_size())
            return None
        # Get the future and chain the callback
        future = work.get_future().then(after_sync_cb)
        # Store the future so we can wait on it later
        self.pending_futures.append(future)
    def register_backward_hook(self, hook_fn):
        self.sync_hooks = []
        for p in self.model.parameters():
            if p.requires_grad:
                # Runs after the new gradient is accumulated into p.grad (PyTorch 2.1+),
                # so we sync the full accumulated gradient, not just this micro batch
                h = p.register_post_accumulate_grad_hook(hook_fn)
                self.sync_hooks.append(h)
    def wait_for_gradients(self):
        for future in self.pending_futures:
            # Wait for gradient all_reduce operation to be
            # enqueued onto a CUDA stream
            future.wait()
        self.pending_futures.clear()
    ...
Here is the full code:
class SimpleDDPAsyncHookGA:
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self._sync = True
        self.pending_futures = []  # Store futures to wait on
        for p in model.parameters():
            rank0_param = p.data.clone()
            dist.broadcast(rank0_param, src=0)
            if not torch.equal(p.data, rank0_param):
                raise ValueError(
                    "Model parameters are not synced across ranks during __init__. "
                    "Use set_seed"
                )
        self.register_backward_hook(self._sync_gradient)
    def _sync_gradient(self, param):
        # Called by autograd after the gradient is accumulated into param.grad
        if not self.should_sync:
            return  # skip syncing
        # Sum the accumulated gradient across all GPUs without blocking
        work = dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, async_op=True)
        def after_sync_cb(fut):
            # Average in-place once the reduction has finished
            param.grad.div_(dist.get_world_size())
            return None
        # Get the future and chain the callback
        future = work.get_future().then(after_sync_cb)
        # Store the future so we can wait on it later
        self.pending_futures.append(future)
    def register_backward_hook(self, hook_fn):
        self.sync_hooks = []
        for p in self.model.parameters():
            if p.requires_grad:
                # Runs after the new gradient is accumulated into p.grad (PyTorch 2.1+),
                # so we sync the full accumulated gradient, not just this micro batch
                h = p.register_post_accumulate_grad_hook(hook_fn)
                self.sync_hooks.append(h)
    def wait_for_gradients(self):
        # Make the current CUDA stream wait on each pending all_reduce and its callback
        for future in self.pending_futures:
            future.wait()
        self.pending_futures.clear()
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    @property
    def should_sync(self):
        return self._sync
    def enable_grad_sync(self):
        self._sync = True
    def disable_grad_sync(self):
        self._sync = False
set_seed()
train_dataloader = get_dataloader(train_shard)
model = get_smol_model()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ddp_model = SimpleDDPAsyncHookGA(model)
max_steps = 6
grad_accum_steps = 3
We set up the profiler and run the training loop on our SimpleDDPAsyncHookGA model the same way as before, except this time we communicate asynchronously. To ensure our all_reduce operations and their callbacks are complete before the optimizer step, we call wait_for_gradients and torch.cuda.synchronize after the backward pass.
with get_profiler("/workspace/ddp/hook-ga-async") as prof:
    for (i, batch) in enumerate(train_dataloader):
        if i >= max_steps:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        if (i + 1) % grad_accum_steps == 0:
            ddp_model.enable_grad_sync()
        else:
            ddp_model.disable_grad_sync()
        output = ddp_model(**batch)
        output.loss.backward()
        if ddp_model.should_sync:
            # We no longer sync the gradients explicitly, it is handled in the
            # backwards hook. We still need to wait for the async operations to
            # complete before moving to the next step
            ddp_model.wait_for_gradients()
            torch.cuda.synchronize()
            optimizer.step()
            optimizer.zero_grad()
        prof.step()
Profiling Async
There are some differences in the trace below compared to our previous example: there is now a second stream where our all_reduce operation executes without blocking the gradient computation. We can zoom into the GPU streams in the embedded trace viewer to compare the blocking, synchronous all_reduce in the previous example with the non-blocking, asynchronous all_reduce in this example.
We also see some additional streams, which show our callback firing and completing the final step of averaging our gradients: dividing by the total number of GPUs.
Looking at the "Computation-Communication Overlap" chart, we can confirm that we have finally overlapped computation and communication! Our "Temporal Breakdown" chart also shows only ~2% non-compute time, a ~60% improvement over our previous implementation and an overall ~87% improvement!
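For reference, here is the arithmetic behind these percentages as a small Python sketch. The inputs are the approximate non-compute fractions of GPU time read off the "Temporal Breakdown" charts (~15% for SimpleDDP, ~5% with gradient accumulation, ~2% with async hooks); the ~60% figure corresponds to going from ~5% to ~2%. Since the inputs are rounded, the results are approximate too, and they describe reductions in non-compute time, not in total step time.

```python
def relative_reduction(before, after):
    """Fractional reduction going from `before` to `after`."""
    return (before - after) / before


# Approximate non-compute fractions of GPU time from the Temporal Breakdown charts
simple_ddp = 0.15
grad_accum = 0.05
async_hooks = 0.02

for label, before, after in [
    ("SimpleDDP -> gradient accumulation", simple_ddp, grad_accum),
    ("gradient accumulation -> async hooks", grad_accum, async_hooks),
    ("SimpleDDP -> async hooks", simple_ddp, async_hooks),
]:
    print(f"{label}: {relative_reduction(before, after):.1%}")
# SimpleDDP -> gradient accumulation: 66.7%
# gradient accumulation -> async hooks: 60.0%
# SimpleDDP -> async hooks: 86.7%
```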
Bucketing Gradients
Reducing Syncs
The last optimization we will implement is bucketing gradients. Rather than running multiple communication operations on smaller tensors, it is more efficient to run fewer communications on larger tensors.
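One common way to reason about this is a simple latency-bandwidth ("alpha-beta") cost model: every message pays a fixed startup latency, plus a transfer time proportional to its size. The sketch below uses made-up latency and bandwidth values (assumptions, not T4 or NCCL measurements) and ignores the details of the all_reduce algorithm.

```python
def comm_time_ms(num_messages, total_mb, latency_ms=0.05, bandwidth_mb_per_ms=10.0):
    """Alpha-beta model: a fixed latency per message plus size / bandwidth.
    latency_ms and bandwidth_mb_per_ms are hypothetical values."""
    return num_messages * latency_ms + total_mb / bandwidth_mb_per_ms


total_mb = 400.0  # hypothetical total gradient size
per_tensor = comm_time_ms(num_messages=400, total_mb=total_mb)  # 400 tensors of 1 MB
bucketed = comm_time_ms(num_messages=16, total_mb=total_mb)     # 16 buckets of 25 MB

print(per_tensor, bucketed)
```

The bandwidth term (40 ms here) is identical in both cases; bucketing only shrinks the per-message latency term, from 20 ms to 0.8 ms. The trade-off is that a bucket cannot be sent until it fills up, so very large buckets delay the start of communication and leave less of it to overlap with the backward pass.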
We can apply this to our model by registering a different kind of backward hook, this time on nodes of the autograd graph. In the autograd graph, a node represents an operation that produced a tensor and knows how to compute the gradients of its inputs from the gradient of its output. For a parameter (a leaf tensor), the relevant node is its AccumulateGrad node, which adds each incoming gradient into param.grad. We register our hook on this node, so it runs after the gradient has been accumulated and sees the full accumulated gradient.
Instead of running all_reduce on individual parameters, we collect them into a "bucket" represented as a list. Once the bucket is full, we flatten the gradients in the list into one tensor and all_reduce it. Once the all_reduce is complete, we average the gradients and unflatten them back into each parameter's .grad.
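The bucketing bookkeeping can be sketched in pure Python, with lists standing in for gradient tensors and two simulated ranks. This sketch assumes fp32 gradients (4 bytes per element) and that gradients become ready in reverse parameter order; make_buckets, flatten, unflatten, and bucketed_all_reduce_mean are illustrative names, not part of our class. It uses the same "size >= cap" check as SimpleDDPBucketGA and flushes the final, partially filled bucket.

```python
ELEMENT_SIZE = 4  # assumption: fp32 gradients, 4 bytes per element


def make_buckets(param_sizes, bucket_cap_bytes):
    """Group parameter names into buckets in backward order (last parameter
    first). A bucket is closed once it reaches the cap; whatever is left at
    the end is flushed as a final, smaller bucket."""
    buckets, current, current_bytes = [], [], 0
    for name, numel in reversed(param_sizes):
        current.append(name)
        current_bytes += numel * ELEMENT_SIZE
        if current_bytes >= bucket_cap_bytes:  # bucket is full: sync it
            buckets.append(current)
            current, current_bytes = [], 0
    if current:  # the equivalent of flush_bucket()
        buckets.append(current)
    return buckets


def flatten(grads, names):
    """Concatenate the gradients of `names` into one flat list."""
    return [value for name in names for value in grads[name]]


def unflatten(flat, grads, names):
    """Split a flat list back into per-parameter lists using running offsets."""
    out, offset = {}, 0
    for name in names:
        numel = len(grads[name])
        out[name] = flat[offset:offset + numel]
        offset += numel
    return out


def bucketed_all_reduce_mean(grads_per_rank, buckets):
    """Simulate one all_reduce(SUM) per bucket, then divide by world size."""
    world_size = len(grads_per_rank)
    result = {}
    for names in buckets:
        flats = [flatten(grads, names) for grads in grads_per_rank]
        mean_flat = [sum(values) / world_size for values in zip(*flats)]
        result.update(unflatten(mean_flat, grads_per_rank[0], names))
    return result


# Hypothetical parameters in forward order: (name, number of elements)
param_sizes = [("embed", 2), ("layer1", 4), ("layer2", 2), ("head", 3)]
buckets = make_buckets(param_sizes, bucket_cap_bytes=16)
print(buckets)  # [['head', 'layer2'], ['layer1'], ['embed']]

grads_per_rank = [
    {"embed": [1.0, 1.0], "layer1": [2.0] * 4, "layer2": [3.0, 3.0], "head": [4.0, 4.0, 4.0]},
    {"embed": [3.0, 3.0], "layer1": [0.0] * 4, "layer2": [1.0, 1.0], "head": [0.0, 2.0, 4.0]},
]
averaged = bucketed_all_reduce_mean(grads_per_rank, buckets)
print(averaged["head"])  # [2.0, 3.0, 4.0]
```

Four parameters need three all_reduce calls here; with a 25 MB cap and a model with many parameter tensors, the reduction in the number of calls is much larger.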
class SimpleDDPBucketGA:
    def __init__(self, model: torch.nn.Module, bucket_cap_mb=25):
        self.model = model
        self.bucket_cap_bytes = bucket_cap_mb * 1024 * 1024
        self._sync = True
        self._bucket = []
        self._bucket_size = 0
        self._register_hooks()
    def _register_hooks(self):
        self._bucket_hooks = []
        # Register hooks to collect gradients and synchronize in buckets
        for p in self.model.parameters():
            if p.requires_grad:
                # This is a trick which gives us access to grad_fn
                param_tmp = p.expand_as(p)
                # We get the node from the grad_fn of the param
                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
                # Register the callback to the node to run in backwards hook
                # after the node's gradient is computed
                grad_acc_fn.register_hook(self._make_bucket_hook(p))
                self._bucket_hooks.append(grad_acc_fn)
    def _make_bucket_hook(self, param):
        def hook(*args):
            if param.grad is not None:
                if self._sync:
                    # Append the param to the bucket so it will be synced
                    self._bucket.append(param)
                    # Increase the bucket size counter
                    self._bucket_size += param.grad.numel() * param.grad.element_size()
                    # Sync if our bucket is full
                    if self._bucket_size >= self.bucket_cap_bytes:
                        self._sync_bucket()
        return hook
    def _sync_bucket(self):
        # Synchronize all gradients in the bucket using a coalesced reduce
        # Extract gradient tensors
        grads = [p.grad for p in self._bucket]
        if not grads:
            return
        # Coalesce gradients into a single contiguous tensor
        flat_grads = torch.cat([g.view(-1) for g in grads])
        # All-reduce the flattened tensor
        dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM)
        # Average the gradients
        flat_grads /= dist.get_world_size()
        # Copy back gradients to their original parameters
        offset = 0
        for p in self._bucket:
            numel = p.grad.numel()
            flat_grad = flat_grads[offset:offset + numel]
            p.grad.data = flat_grad.view_as(p.grad)
            offset += numel
        self._bucket = []
        self._bucket_size = 0
    def flush_bucket(self):
        # Flush any remaining gradients in the bucket
        if self._bucket:
            self._sync_bucket()
    ...
Here is the full code:
class SimpleDDPBucketGA:
    def __init__(self, model: torch.nn.Module, bucket_cap_mb=25):
        self.model = model
        self.bucket_cap_bytes = bucket_cap_mb * 1024 * 1024
        self._sync = True
        self._bucket = []
        self._bucket_size = 0
        self._register_hooks()
    def _register_hooks(self):
        self._bucket_hooks = []
        # Register hooks to collect gradients and synchronize in buckets
        for p in self.model.parameters():
            if p.requires_grad:
                # This is a trick which gives us access to grad_fn
                param_tmp = p.expand_as(p)
                # We get the node from the grad_fn of the param
                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
                # Register the callback to the node to run in backwards hook
                # after the node's gradient is computed
                grad_acc_fn.register_hook(self._make_bucket_hook(p))
                self._bucket_hooks.append(grad_acc_fn)
    def _make_bucket_hook(self, param):
        def hook(*args):
            if param.grad is not None:
                if self._sync:
                    # Append the param to the bucket so it will be synced
                    self._bucket.append(param)
                    # Increase the bucket size counter
                    self._bucket_size += param.grad.numel() * param.grad.element_size()
                    # Sync if our bucket is full
                    if self._bucket_size >= self.bucket_cap_bytes:
                        self._sync_bucket()
        return hook
    def _sync_bucket(self):
        # Synchronize all gradients in the bucket using a coalesced reduce
        # Extract gradient tensors
        grads = [p.grad for p in self._bucket]
        if not grads:
            return
        # Coalesce gradients into a single contiguous tensor
        flat_grads = torch.cat([g.view(-1) for g in grads])
        # All-reduce the flattened tensor
        dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM)
        # Average the gradients
        flat_grads /= dist.get_world_size()
        # Copy back gradients to their original parameters
        offset = 0
        for p in self._bucket:
            numel = p.grad.numel()
            flat_grad = flat_grads[offset:offset + numel]
            p.grad.data = flat_grad.view_as(p.grad)
            offset += numel
        self._bucket = []
        self._bucket_size = 0
    def flush_bucket(self):
        # Flush any remaining gradients in the bucket
        if self._bucket:
            self._sync_bucket()
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    @property
    def should_sync(self):
        return self._sync
    def enable_grad_sync(self):
        self._sync = True
    def disable_grad_sync(self):
        self._sync = False
We initialize the model and dataloader.
set_seed()
train_dataloader = get_dataloader(train_shard)
model = get_smol_model()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ddp_model = SimpleDDPBucketGA(model)
max_steps = 6
grad_accum_steps = 3
We set up the profiler and run the training loop the same way as earlier, with one slight change.
After the backward pass we call flush_bucket, to sync any gradients left in a partially filled bucket.
with get_profiler("/workspace/ddp/bucket-ga") as prof:
    for (i, batch) in enumerate(train_dataloader):
        if i >= max_steps:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        if (i + 1) % grad_accum_steps == 0:
            ddp_model.enable_grad_sync()
        else:
            ddp_model.disable_grad_sync()
        output = ddp_model(**batch)
        output.loss.backward()
        if ddp_model.should_sync:
            # Sync one last time to ensure any remaining gradients are synced
            ddp_model.flush_bucket()
            optimizer.step()
            optimizer.zero_grad()
        prof.step()
Profiling
Looking at the screenshot below, we can see our all_reduce calls represented by the pink intermittent blocks. This can be confirmed by hovering over those blocks in the embedded trace.
We have confirmed that our gradient synchronization is now chunked during the backward pass. However, the "Computation-Communication Overlap" chart shows that once again the synchronization is blocking computation. Let's make it async!
Asynchronous Communication
Now let's try again with non-blocking communication by setting async_op=True, and running all_reduce on the bucketed gradients asynchronously.
We add a callback, _after_sync_cb, that completes the averaging and copies the gradients back once the gradient sync is complete. As in our previous async hook example, we also add a wait_for_gradients helper that waits on the pending futures so the operations are ordered on the GPU's CUDA stream before the optimizer step.
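The snapshot taken in _sync_bucket matters because the callback runs later, after the backward pass may already have started refilling the bucket. The toy class below (hypothetical, plain Python, no real communication) contrasts a callback that closes over the snapshot with one that reads self._bucket only when it finally runs.

```python
class ToyAsyncBucket:
    """Hypothetical stand-in for the bucket bookkeeping in SimpleDDPAsyncBucketGA.
    Callbacks in self._pending play the role of chained futures."""

    def __init__(self, snapshot=True):
        self.snapshot = snapshot
        self._bucket = []
        self._pending = []

    def add(self, name):
        self._bucket.append(name)

    def sync_bucket(self):
        bucket_params = list(self._bucket)  # snapshot before clearing
        if self.snapshot:
            # Correct: the callback closes over the snapshot
            self._pending.append(lambda: bucket_params)
        else:
            # Buggy: the callback reads self._bucket only when it finally runs
            self._pending.append(lambda: self._bucket)
        # Clear bucket state early so hooks can refill it while the op is "in flight"
        self._bucket = []

    def run_callbacks(self):
        """Run the pending callbacks, as if the all_reduce had just finished."""
        return [callback() for callback in self._pending]


def simulate(snapshot):
    bucket = ToyAsyncBucket(snapshot=snapshot)
    bucket.add("head")
    bucket.add("layer2")
    bucket.sync_bucket()  # async all_reduce launched for [head, layer2]
    bucket.add("layer1")  # backward continues and refills the bucket
    return bucket.run_callbacks()


print(simulate(snapshot=True))   # [['head', 'layer2']]
print(simulate(snapshot=False))  # [['layer1']]
```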
class SimpleDDPAsyncBucketGA:
    def __init__(self, model: torch.nn.Module, bucket_cap_mb=25):
        ...
        # List of pending Futures to wait on
        self._pending_futures = []
        ...
    def _sync_bucket(self):
        ...
        bucket_params = list(self._bucket)  # snapshot before clearing
        # Clear bucket state early so hooks can refill while async op in flight
        self._bucket = []
        self._bucket_size = 0
        # Coalesce gradients into a single contiguous tensor
        flat_grads = torch.cat([g.view(-1) for g in grads])
        # Launch async all-reduce on the coalesced tensor
        work = dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM, async_op=True)
        # Register callback to run after gradient sync
        def _after_sync_cb(fut):
            # Average and copy back once reduction finished
            flat_grads.div_(dist.get_world_size())
            # Copy back gradients
            offset = 0
            for p in bucket_params:
                numel = p.grad.numel()
                flat_grad = flat_grads[offset:offset + numel]
                p.grad.copy_(flat_grad.view_as(p.grad))
                offset += numel
            return None
        # Get the future, chain the callback, and store the resulting future
        future = work.get_future().then(_after_sync_cb)
        self._pending_futures.append(future)
    def wait_for_gradients(self):
        # Wait for all pending async gradient operations to complete
        for future in self._pending_futures:
            future.wait()
        self._pending_futures.clear()
    ...
Here is the full code:
class SimpleDDPAsyncBucketGA:
    def __init__(self, model: torch.nn.Module, bucket_cap_mb=25):
        self.model = model
        self.bucket_cap_bytes = bucket_cap_mb * 1024 * 1024
        self._sync = True
        self._bucket = []
        self._bucket_size = 0
        # List of pending Futures to wait on
        self._pending_futures = []
        self._register_hooks()
    def _register_hooks(self):
        self._bucket_hooks = []
        # Register hooks to collect gradients and synchronize in buckets
        for p in self.model.parameters():
            if p.requires_grad:
                # Trick to get access to grad_fn (retain graph node)
                param_tmp = p.expand_as(p)
                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
                grad_acc_fn.register_hook(self._make_bucket_hook(p))
                self._bucket_hooks.append(grad_acc_fn)
    def _make_bucket_hook(self, param):
        def hook(*args):
            if param.grad is not None and self._sync:
                # Append the param to the bucket so it will be synced
                self._bucket.append(param)
                # Increase the bucket size counter
                self._bucket_size += param.grad.numel() * param.grad.element_size()
                # Sync if bucket full
                if self._bucket_size >= self.bucket_cap_bytes:
                    self._sync_bucket()
        return hook
    def _sync_bucket(self):
        # Synchronize all gradients in the bucket using a coalesced reduce
        grads = [p.grad for p in self._bucket]
        if not grads:
            return
        bucket_params = list(self._bucket)  # snapshot before clearing
        # Clear bucket state early so hooks can refill while async op in flight
        self._bucket = []
        self._bucket_size = 0
        # Coalesce gradients into a single contiguous tensor
        flat_grads = torch.cat([g.view(-1) for g in grads])
        # Launch async all-reduce on the coalesced tensor
        work = dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM, async_op=True)
        # Register callback to run after gradient sync
        def _after_sync_cb(fut):
            # Average and copy back once reduction finished
            flat_grads.div_(dist.get_world_size())
            # Copy back gradients
            offset = 0
            for p in bucket_params:
                numel = p.grad.numel()
                flat_grad = flat_grads[offset:offset + numel]
                p.grad.copy_(flat_grad.view_as(p.grad))
                offset += numel
            return None
        # Get the future, chain the callback, and store the resulting future
        future = work.get_future().then(_after_sync_cb)
        self._pending_futures.append(future)
    def flush_bucket(self):
        # Flush any remaining gradients in the bucket
        if self._bucket:
            self._sync_bucket()
    def wait_for_gradients(self):
        # Wait for all pending async gradient operations to complete
        for future in self._pending_futures:
            future.wait()
        self._pending_futures.clear()
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    @property
    def should_sync(self):
        return self._sync
    def enable_grad_sync(self):
        self._sync = True
    def disable_grad_sync(self):
        self._sync = False
set_seed()
train_dataloader = get_dataloader(train_shard)
model = get_smol_model()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ddp_model = SimpleDDPAsyncBucketGA(model)
max_steps = 6
grad_accum_steps = 3
We set up the profiler and run the training loop on our SimpleDDPAsyncBucketGA model the same way we did previously, except this time the gradient communication is asynchronous.
with get_profiler("/workspace/ddp/bucket-ga-async") as prof:
    for (i, batch) in enumerate(train_dataloader):
        if i >= max_steps:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        if (i + 1) % grad_accum_steps == 0:
            ddp_model.enable_grad_sync()
        else:
            ddp_model.disable_grad_sync()
        output = ddp_model(**batch)
        output.loss.backward()
        if ddp_model.should_sync:
            # Flush any remaining gradients in the bucket
            ddp_model.flush_bucket()
            # Wait for all async gradient sync operations to complete
            ddp_model.wait_for_gradients()
            torch.cuda.synchronize()
            optimizer.step()
            optimizer.zero_grad()
        prof.step()
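The `(i + 1) % grad_accum_steps == 0` condition decides which micro-steps trigger a gradient sync. As a quick sanity check, here's a pure-Python sketch (no GPUs or torch.distributed involved; `sync_schedule` is a hypothetical helper, not part of the training code) that lists the micro-steps on which `enable_grad_sync()` would be called with `max_steps = 6` and `grad_accum_steps = 3`.

```python
def sync_schedule(max_steps: int, grad_accum_steps: int) -> list[int]:
    """Hypothetical helper: micro-step indices where the loop enables grad sync."""
    # Mirrors `if (i + 1) % grad_accum_steps == 0` from the training loop
    return [i for i in range(max_steps) if (i + 1) % grad_accum_steps == 0]


steps = sync_schedule(max_steps=6, grad_accum_steps=3)
print(steps)  # → [2, 5]
# Every sync step ends with optimizer.step(); the others only accumulate locally
print(f"{len(steps)} optimizer steps, {6 - len(steps)} local-only backward passes")
```

So the six micro-steps produce two optimizer steps, and the bucket hooks only launch communication during the backward passes of micro-steps 2 and 5.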
Profiling Async
As in our hook example, where we compared synchronous and asynchronous communication, we see a second stream on which our all_reduce operation executes without blocking the gradient computation. As before, we can zoom into the GPU streams in the embedded trace viewer to see the difference between the blocking, synchronous all_reduce in the previous example and the non-blocking, asynchronous all_reduce in this one.
We also see some additional streams showing that our callback fires and completes the final step: averaging the gradients by dividing them by the total number of GPUs.
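The callback arithmetic is easy to check without any GPUs. Below is a single-process sketch that follows the same coalesce, async reduce, average, copy-back sequence as `_sync_bucket` and `_after_sync_cb`. It assumes plain Python lists stand in for gradient tensors, a `ThreadPoolExecutor` stands in for the background communication, `simulated_all_reduce_sum` stands in for `dist.all_reduce(op=SUM)`, and a hypothetical `then` helper mimics `torch.futures.Future.then`. Unlike the real callback, which copies into `p.grad` in place, this sketch returns new lists.

```python
from concurrent.futures import Future, ThreadPoolExecutor

WORLD_SIZE = 4  # assumption: four simulated ranks, one per T4


def then(fut: Future, fn) -> Future:
    """Hypothetical stand-in for torch.futures.Future.then: chain fn onto fut."""
    chained: Future = Future()

    def _callback(done: Future) -> None:
        try:
            chained.set_result(fn(done))
        except Exception as exc:  # propagate errors to whoever waits
            chained.set_exception(exc)

    fut.add_done_callback(_callback)
    return chained


def simulated_all_reduce_sum(flat_per_rank):
    """Stand-in for dist.all_reduce(op=SUM): every rank receives the elementwise sum."""
    summed = [sum(vals) for vals in zip(*flat_per_rank)]
    return [list(summed) for _ in flat_per_rank]


def average_and_unflatten(flat, numels, world_size):
    """Mirror of _after_sync_cb: divide by world size, then slice per parameter."""
    flat = [x / world_size for x in flat]
    grads, offset = [], 0
    for numel in numels:
        grads.append(flat[offset:offset + numel])
        offset += numel
    return grads


def sync_bucket(executor, bucket_grads_per_rank):
    """Coalesce one bucket per rank, launch the reduce without blocking, chain the callback."""
    numels = [len(g) for g in bucket_grads_per_rank[0]]
    flats = [[x for g in grads for x in g] for grads in bucket_grads_per_rank]  # like torch.cat
    work = executor.submit(simulated_all_reduce_sum, flats)  # returns immediately
    return then(work, lambda f: [average_and_unflatten(flat, numels, WORLD_SIZE) for flat in f.result()])


# Rank r holds gradients [r, 2r] for parameter A and [10 + r] for parameter B
grads_per_rank = [[[r, 2 * r], [10 + r]] for r in range(WORLD_SIZE)]
with ThreadPoolExecutor(max_workers=1) as pool:
    pending = [sync_bucket(pool, grads_per_rank)]
    # ...the backward pass would keep producing the next bucket here...
    synced = [fut.result() for fut in pending]  # analogous to wait_for_gradients()
print(synced[0][0])  # rank 0 after sync → [[1.5, 3.0], [11.5]]
```

Every rank ends up with the same averaged gradients, which is exactly the invariant DDP needs before `optimizer.step()` so the model replicas stay identical.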
The "Computation-Communication Overlap" chart shows an improvement over the synchronous version, but, at least with this configuration of four Nvidia T4 GPUs, we do not see much difference compared to the non-bucketed version. It could be that gradient accumulation has already reduced the number of synchronizations enough that bucketing has little left to save. Implementing the bucketed version was a good exercise, but the data suggests that, as long as we keep using gradient accumulation on this GPU configuration, the additional complexity may not be worth it in practice.
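To put rough numbers on the idea that gradient accumulation already removes most of the synchronizations, here's a back-of-the-envelope sketch that replays the bucket hook's fill-and-flush logic and counts all_reduce launches. The gradient sizes and the 25 MB cap are hypothetical (25 MB happens to be the default `bucket_cap_mb` of PyTorch's built-in `DistributedDataParallel`, but that is not assumed to match `bucket_cap_bytes` in our class), and it assumes the non-bucketed hook version launches one all_reduce per parameter.

```python
MB = 1024 ** 2


def count_bucket_syncs(param_bytes, bucket_cap_bytes):
    """Count all_reduce launches in one synced backward pass.

    Mirrors _make_bucket_hook (sync once the bucket reaches the cap)
    followed by flush_bucket (sync whatever is left over).
    """
    syncs, bucket_size, bucket_len = 0, 0, 0
    for nbytes in param_bytes:  # order in which the AccumulateGrad hooks fire
        bucket_len += 1
        bucket_size += nbytes
        if bucket_size >= bucket_cap_bytes:
            syncs += 1
            bucket_size, bucket_len = 0, 0
    if bucket_len:  # flush_bucket() sends the partially filled bucket
        syncs += 1
    return syncs


# Hypothetical gradient sizes, listed in hook-firing order
param_bytes = [10 * MB, 10 * MB, 10 * MB, 4 * MB, 30 * MB, 1 * MB]
cap = 25 * MB  # assumed bucket cap
max_steps, grad_accum_steps = 6, 3
sync_steps = max_steps // grad_accum_steps  # backward passes that communicate

per_param_every_step = len(param_bytes) * max_steps
per_param_with_ga = len(param_bytes) * sync_steps
bucketed_with_ga = count_bucket_syncs(param_bytes, cap) * sync_steps
print(per_param_every_step, per_param_with_ga, bucketed_with_ga)  # → 36 12 6
```

In this toy setup, gradient accumulation alone cuts the collectives by two thirds, and bucketing only halves what remains. Whether that remaining saving shows up in wall-clock time is something only the profiler trace can answer, not this count.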
The trace files are loaded from my server and might be slow to load. This blog is hosted on a Kubernetes cluster I run in my home, so hopefully it doesn't blow up 😅. You can get a live view of what is running in my cluster on the home page!
For a better viewing experience, collapse the sidebar by clicking the "hamburger" icon, then click the "expand all" button a little further down on the left side.
Conclusion
Looking back at our work, we've built a simple version of Distributed Data Parallelism and experimented with improving its performance by adding gradient accumulation, overlapping communication with computation using backward hooks, and bucketing gradients to reduce the number of syncs. We also visualized how operations executed on the GPUs in every case, compared them with single-GPU training, and saw how the execution changed between synchronous and asynchronous communication.
Although our first attempt at implementing DDP resulted in 0% computation-communication overlap and ~15% non-compute time, we ultimately achieved a reduction of ~87% in non-compute time and ~90% overlap. We now have the foundations and skills to train a model using multiple GPUs! (As long as our model fits on a single GPU!)