Implementing a single-node DDP training loop

Author

Saideep Gona

Published

September 25, 2023

Code
import h5py
import numpy as np
import torch
import os,sys
import time
import copy

import kipoiseq

sys.path.append('/beagle3/haky/users/saideep/github_repos/Daily-Blog-Sai/posts/2023-09-25-consolidate_aracena_training')

import ref_to_aracena_models
# importlib.reload(ref_to_aracena_models)

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

# needed below for the DDP process group and model wrapper
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
Code
RefToAracenaMLP(
  (model): Sequential(
    (0): BatchNorm1d(5313, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Linear(in_features=5313, out_features=1024, bias=True)
    (2): Softplus(beta=1, threshold=20)
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): Softplus(beta=1, threshold=20)
    (6): Dropout(p=0.1, inplace=False)
    (7): Linear(in_features=512, out_features=256, bias=True)
    (8): Softplus(beta=1, threshold=20)
    (9): Dropout(p=0.1, inplace=False)
    (10): Linear(in_features=256, out_features=256, bias=True)
    (11): Softplus(beta=1, threshold=20)
    (12): Dropout(p=0.1, inplace=False)
  )
)

Context

Having implemented a basic training scheme and demonstrated it on subsets of the training dataset, it is now time to refocus on the training loop so that training on the full dataset is more efficient. One part of this is implementing DistributedDataParallel (DDP) to take advantage of all the GPUs on a given compute node, which is faster and tends to be more efficient per unit of RAM used, which in turn translates into better use of SUs on clusters such as Beagle. The following blog post is very helpful:

https://medium.com/codex/a-comprehensive-tutorial-to-pytorch-distributeddataparallel-1f4b42bb1b51

sinteractive --account=pi-haky --partition=beagle3 --gres=gpu:4 --time=2:00:00
conda activate /beagle3/haky/users/temi/software/conda_envs/dl-tools
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/beagle3/haky/users/temi/software/conda_envs/dl-tools/lib
my_ip_address=$( /sbin/ip route get 8.8.8.8 | awk '{print $7;exit}' )
jupyter-notebook --no-browser --ip=$my_ip_address --port=15005

On Beagle, allocating 4 GPUs automatically reserves 250 GB of memory as well.
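
As a quick sanity check (not part of the original notebook), the snippet below confirms that all four GPUs are visible from inside the interactive job before any DDP workers are spawned.

Code
import torch

print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"  cuda:{i}: {props.name}, {props.total_memory / 1e9:.1f} GB on-device memory")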

Modify dataloader for parallelization

Code
from torch.utils.data.distributed import DistributedSampler

class HDF5DatasetRefAracena(Dataset):

    def __init__(self, enformer_ref_targets, aracena_targets, mapping_table, individual, summarization="mean", bin_size=128):
        super().__init__()

        import pandas as pd

        self.bin_size = bin_size
        self.summarization = summarization
        self.individual = individual
        self.mapping_table = mapping_table
        self.files = {}
        self.regions = {}

        if summarization == "mean":
            with h5py.File(enformer_ref_targets, 'r') as ert:
                self.files['ref_targets'] = torch.Tensor(ert['ref_epigenome'][:,:,:].swapaxes(0,1)).float()
                self.regions['ref_targets'] = ert['regions'][:,:,:]
            with h5py.File(aracena_targets, 'r') as at:
                self.files['aracena_targets'] = torch.Tensor(at['targets'][:,:,:].swapaxes(0,1)).float()
                self.regions['aracena_targets'] = at['regions'][:,:,:]

        elif summarization == "sum":
            with h5py.File(enformer_ref_targets, 'r') as ert:
                self.files['ref_targets'] = torch.mul(torch.Tensor(ert['ref_epigenome'][:,:,:].swapaxes(0,1)).float(),bin_size)
                self.regions['ref_targets'] = ert['regions'][:,:,:]
            with h5py.File(aracena_targets, 'r') as at:
                self.files['aracena_targets'] = torch.mul(torch.Tensor(at['targets'][:,:,:].swapaxes(0,1)).float(), bin_size)           
                self.regions['aracena_targets'] = at['regions'][:,:,:]

    def __getitem__(self, index):

        region = self.mapping_table.columns[index]

        dt = {}
        dt['ref_targets'] = self.files['ref_targets'][: , :, int(self.mapping_table.loc["ref-epigenome", region])].swapaxes(0,1)
        dt['aracena_targets'] = self.files['aracena_targets'][:, :,  int(self.mapping_table.loc[self.individual, region])].swapaxes(0,1)
        dt['ref_region'] = self.regions['ref_targets'][0,:,int(self.mapping_table.loc["ref-epigenome", region])]
        dt['ara_region'] = self.regions['aracena_targets'][0,:,int(self.mapping_table.loc[self.individual, region])]
                                
        return dt

    def __len__(self):
        return len(self.mapping_table.columns)


def prepare_ddp_dataloader(rank, world_size, path_to_input, path_to_targets, mapping_table, individual, batch_size=32, pin_memory=False, num_workers=0):

    dataset = HDF5DatasetRefAracena(path_to_input, path_to_targets, mapping_table, individual, summarization="mean")

    # DistributedSampler hands each rank a disjoint shard of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)

    # shuffle must stay False here because shuffling is delegated to the sampler
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)

    return dataloader
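
To see how the sampler shards the data, here is a minimal standalone illustration (toy data, not the HDF5 dataset above): each rank iterates over a disjoint subset of indices, and per-epoch shuffling is handled later through sampler.set_epoch() in the training loop.

Code
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# toy dataset with 10 items, pretending world_size = 2
toy = TensorDataset(torch.arange(10).float())

for rank in range(2):
    sampler = DistributedSampler(toy, num_replicas=2, rank=rank, shuffle=False, drop_last=False)
    print(f"rank {rank} gets indices:", list(sampler))
# rank 0 gets indices: [0, 2, 4, 6, 8]
# rank 1 gets indices: [1, 3, 5, 7, 9]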

Set up DDP process group

Code
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo',
                            init_method='env://',
                            rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
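
As a sanity check that the process group comes up correctly (a sketch, not part of the original post; _check_ddp is a throwaway helper), each rank can all-reduce a tensor of ones and verify the sum equals world_size.

Code
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _check_ddp(rank, world_size):
    setup(rank, world_size)
    t = torch.ones(1)
    if torch.cuda.is_available():
        t = t.to(rank)
    # sums the tensor across all ranks in place
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: all_reduce gave {t.item()} (expected {world_size})")
    cleanup()

if __name__ == '__main__':
    world_size = 4
    mp.spawn(_check_ddp, args=(world_size,), nprocs=world_size)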

Wrap model + training loop with DDP

Code
def run_training_loop(rank, world_size):
    
    print(f"Running DDP on rank {rank}.")
    
    # Run setup function
    
    setup(rank, world_size)
    
    # Data loading info
    
    group = "train18"
    individual_cur = "AF20"
    # enformer_pred_path, aracena_hdf5_path, and mapping_table_subgroup are defined earlier in the notebook
    dataloader = prepare_ddp_dataloader(rank, world_size, enformer_pred_path, aracena_hdf5_path, mapping_table_subgroup, individual_cur, batch_size=32, pin_memory=False, num_workers=0)
    
    epochs = 100
    lr = 0.001
    model_type = "cnn"

    # wrap model with DDP wrapper
    if model_type == "linear":
        cur_model = ref_to_aracena_models.RefToAracenaMLP(basenji_tracks=1, num_aracena_tracks=1, hidden_layer_dims=[])
    elif model_type == "cnn":
        cur_model = ref_to_aracena_models.RefToAracenaCNN(basenji_tracks=1, num_aracena_tracks=1)

    # each process drives the GPU matching its rank
    cur_model = cur_model.to(rank)
    ddp_model = DDP(cur_model, device_ids=[rank], output_device=rank)

    if rank == 0:
        print(cur_model)

    optimizer = torch.optim.Adam(
            cur_model.parameters(),
            lr=lr)

    criterion = torch.nn.PoissonNLLLoss(log_input=True, reduction="none")

    run_iter = 0

    loss_tracker = []
    loss_all = []

    nan_count = 0
    good_count = 0

    for epoch in range(epochs):
        print("Epoch:",epoch)
        epoch_start = time.perf_counter()
        
        dataloader.sampler.set_epoch(epoch)
        
        for j, batch in enumerate(dataloader):

            if torch.sum(torch.isnan(batch["ref_targets"])) > 0:
                nan_count += 1
                continue
            else:
                good_count += 1

            
            optimizer.zero_grad()
            
            # Functions for plotting inputs/outputs as needed
            
            # print("ref_regions:", batch['ref_region'])
            # print("ara_regions:", batch['ara_region'])
            region_numpy = batch["ref_region"][0].numpy()
            resized_int = kipoiseq.Interval(*["chr"+str(region_numpy[0]), region_numpy[1], region_numpy[2]]).resize(896*128)
            interval_for_plot = [f"Epoch:{epoch} "+resized_int.chrom,resized_int.start,resized_int.end]


            input = batch["ref_targets"]
            # print(input.shape)
            # print(input)

            target = batch["aracena_targets"].to(device) 
            # print(target.shape)
            # print(target)


            if model_type == "linear":
                input = input.to(device)
                target = target.to(device)
            elif model_type == "cnn":
                input = input.swapaxes(1,2)[:,4766,:].reshape((-1,1,896)).to(device)
                target = target.swapaxes(1,2)[:,0,:].reshape((-1,1,896)).to(device)
            # print(input.shape)
            model_predictions = cur_model(input)
            

            # only rank 0 prints shapes and plots, to avoid duplicated output across processes
            if j == 0 and rank == 0:
                print(resized_int)
                print("input shape",input.shape)
                print("target shape",target.shape)
                print("model_predictions shape",model_predictions.shape)
                if model_type == "linear":
                    tracks_for_plot = {
                        "ref": input[0,:,0].cpu().numpy(), 
                        "aracena": target[0,:,0].cpu().numpy(),
                        "pred": model_predictions[0,:,0].cpu().detach().numpy()
                    }
                elif model_type == "cnn":
                    tracks_for_plot = {
                        "ref": input[0,0,:].cpu().numpy(), 
                        "aracena": target[0,0,:].cpu().numpy(),
                        "pred": model_predictions[0,0,:].cpu().detach().numpy()
                    }
                    
                
                plot_tracks(tracks_for_plot, interval_for_plot)
                


            loss = criterion(model_predictions, target)


            if torch.sum(torch.isnan(loss)) > 0:
                print(input)
                print(target)
                print(loss)
                print("epoch_",epoch, "_iter_",j)
                continue
            mean_loss = float(loss.mean())
            if mean_loss < 0:
                print("Negative loss", resized_int)
            loss_all.append(loss.cpu().detach().numpy())
            loss_tracker.append(mean_loss)
            

            loss.mean().backward()
            # torch.nn.utils.clip_grad_norm_(cur_model.parameters(), 1)
            optimizer.step()
    cleanup()
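
One refinement not included in the loop above: loss_tracker only records each rank's local losses, and every rank would write any saved artifacts. A common DDP pattern (sketched below; log_and_checkpoint and checkpoint.pt are hypothetical names) is to average the batch loss across ranks with all_reduce and let only rank 0 print and save, reading the weights from ddp_model.module.

Code
import torch
import torch.distributed as dist

def log_and_checkpoint(rank, loss, ddp_model, epoch, out_path="checkpoint.pt"):
    # average the local batch loss over all ranks so every process sees the same value
    batch_loss = loss.mean().detach().clone()
    dist.all_reduce(batch_loss, op=dist.ReduceOp.SUM)
    batch_loss /= dist.get_world_size()

    if rank == 0:
        print(f"epoch {epoch}: cross-rank mean loss = {batch_loss.item():.4f}")
        # DDP wraps the model, so the raw weights live under .module
        torch.save(ddp_model.module.state_dict(), out_path)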

Run the DDP training

Code
import torch.multiprocessing as mp

if __name__ == '__main__':
    # one process per GPU on the node
    world_size = 4
    mp.spawn(
        run_training_loop,
        args=(world_size,),   # args must be a tuple; the rank is supplied automatically by mp.spawn
        nprocs=world_size
    )
ProcessExitedException: process 1 terminated with exit code 1
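
The ProcessExitedException above only reports that a child process died; the underlying traceback is easier to see by running a single rank directly in the main process. A debugging sketch using the functions defined above:

Code
if __name__ == '__main__':
    # world_size=1 still initializes a (trivial) process group, so the full
    # traceback from run_training_loop prints in the current process
    run_training_loop(rank=0, world_size=1)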