Training the predicted-epigenome-to-Aracena model using the full dataset

Author

Saideep Gona

Published

October 2, 2023

Context

Having demonstrated that the Enformer-predicted epigenome is a good predictor of the Aracena epigenome, we can now continue the training process using the entire training corpus and for longer (more epochs).

For an interactive GPU session on Beagle3:

sinteractive --account=pi-haky --partition=beagle3 --gres=gpu:1 --mem=120GB --time=1:00:00
conda activate /beagle3/haky/users/shared_software/dl-tools
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/beagle3/haky/users/shared_software/dl-tools/lib
my_ip_address=$( /sbin/ip route get 8.8.8.8 | awk '{print $7;exit}' )
jupyter-notebook --no-browser --ip=$my_ip_address --port=15005

Code
import h5py
import numpy as np
import os, sys
import time
import copy

import kipoiseq

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

sys.path.append('/beagle3/haky/users/saideep/github_repos/Daily-Blog-Sai/posts/2023-10-02-train_pred_to_aracena_full')

import ref_to_aracena_models
# importlib.reload(ref_to_aracena_models)

Define functions for data loading

Code
class HDF5DatasetRefAracena(Dataset):

    def __init__(self, enformer_ref_targets, aracena_targets, mapping_table, individual, summarization="mean", bin_size=128):
        super().__init__()

        import pandas as pd

        self.bin_size = bin_size
        self.summarization = summarization
        self.individual = individual
        self.mapping_table = mapping_table
        self.files = {}
        self.regions = {}

        # "mean" keeps the per-bin values as stored in the HDF5 files;
        # "sum" scales them by the bin size to give per-bin sums.
        if summarization == "mean":
            with h5py.File(enformer_ref_targets, 'r') as ert:
                self.files['ref_targets'] = torch.Tensor(ert['ref_epigenome'][:,:,:].swapaxes(0,1)).float()
                self.regions['ref_targets'] = ert['regions'][:,:,:]
            with h5py.File(aracena_targets, 'r') as at:
                self.files['aracena_targets'] = torch.Tensor(at['targets'][:,:,:].swapaxes(0,1)).float()
                self.regions['aracena_targets'] = at['regions'][:,:,:]

        elif summarization == "sum":
            with h5py.File(enformer_ref_targets, 'r') as ert:
                self.files['ref_targets'] = torch.mul(torch.Tensor(ert['ref_epigenome'][:,:,:].swapaxes(0,1)).float(), bin_size)
                self.regions['ref_targets'] = ert['regions'][:,:,:]
            with h5py.File(aracena_targets, 'r') as at:
                self.files['aracena_targets'] = torch.mul(torch.Tensor(at['targets'][:,:,:].swapaxes(0,1)).float(), bin_size)
                self.regions['aracena_targets'] = at['regions'][:,:,:]

    def __getitem__(self, index):

        region = self.mapping_table.columns[index]

        # Columns of the mapping table are regions; rows give the index of each region
        # within the reference-epigenome and individual-specific HDF5 files.
        dt = {}
        dt['ref_targets'] = self.files['ref_targets'][:, :, int(self.mapping_table.loc["ref-epigenome", region])].swapaxes(0,1)
        dt['aracena_targets'] = self.files['aracena_targets'][:, :, int(self.mapping_table.loc[self.individual, region])].swapaxes(0,1)
        dt['ref_region'] = self.regions['ref_targets'][0,:,int(self.mapping_table.loc["ref-epigenome", region])]
        dt['ara_region'] = self.regions['aracena_targets'][0,:,int(self.mapping_table.loc[self.individual, region])]

        return dt

    def __len__(self):
        return len(self.mapping_table.columns)
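
The dataset class assumes each group's HDF5 file exposes a main targets dataset ('ref_epigenome' for the Enformer predictions, 'targets' for the Aracena tracks) alongside a 'regions' dataset. A quick way to check one file's layout; the file name below is just an example following the naming pattern used later in this post.

Code
# Hedged sketch: list the datasets in one group's HDF5 file to confirm the
# layout expected by HDF5DatasetRefAracena above.
example_h5 = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/ref-epigenome_valid_aracena.h5"
with h5py.File(example_h5, 'r') as f:
    for name in f.keys():
        print(name, f[name].shape, f[name].dtype)
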
Code
def load_group_dataset(mapping_table, group, individual_cur, enformer_pred_path, aracena_hdf5_path, batch_size=32):

    # Restrict the mapping table to the regions belonging to the requested group
    # (e.g. "train1"..."train18" or "valid")
    mapping_table_subgroup = mapping_table.loc[:, mapping_table.loc["group"] == group]

    dataset = HDF5DatasetRefAracena(enformer_pred_path, aracena_hdf5_path, mapping_table_subgroup, individual_cur, summarization="mean")

    trainloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        # num_workers=DISTENV['world_size'],
        # collate_fn=concat_pairs,
        pin_memory=torch.cuda.is_available(),
    )
    
    return trainloader

Load mapping table

Code
import pandas as pd

mapping_table = pd.read_csv("/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/index_files/remapped_table_filt.csv", index_col=0)
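
The dataset class above queries this table with one column per region and rows that include a "group" label, a "ref-epigenome" index, and an index row per individual (e.g. "AF20"). A quick look at its layout as a sanity check:

Code
# Inspect the mapping table: rows should include "group", "ref-epigenome",
# and one row per individual; columns are the regions.
print(mapping_table.shape)
print(mapping_table.index.tolist())
print(mapping_table.iloc[:, :3])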

Define functions for metric reporting during training
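
A minimal sketch of the kind of helper this section could hold is a per-batch Pearson correlation between predictions and targets. The function below is an assumption and is not called in the training loop further down.

Code
# Hedged sketch (not used below): per-batch Pearson correlation between
# model predictions and Aracena targets, as a simple training-time metric.
def batch_pearson_r(predictions, targets):
    preds = predictions.detach().flatten().float()
    targs = targets.detach().flatten().float()
    preds = preds - preds.mean()
    targs = targs - targs.mean()
    denom = preds.norm() * targs.norm()
    if denom == 0:
        return float("nan")
    return float((preds * targs).sum() / denom)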

Define other utility functions for training

Code
def group_sampler(groups):
    # Return a shuffled copy of the group list so that the groups are visited
    # in a different random order each epoch.
    import random

    group_shuffled = groups[:]
    random.shuffle(group_shuffled)

    return group_shuffled

Main training loop

Code
# Training hyperparameters
epochs = 2
sub_epochs = 3
lr = 0.001
model_type = "cnn"
individual_cur = "AF20"

# The training regions are split across 18 groups (train1 through train18)
groups = ["train"+str(x) for x in range(1,19)]

# Load model

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

if model_type == "linear":
    cur_model = ref_to_aracena_models.RefToAracenaMLP(basenji_tracks=1, num_aracena_tracks=1, hidden_layer_dims=[])
elif model_type == "cnn":
    cur_model = ref_to_aracena_models.RefToAracenaCNN(basenji_tracks=1, num_aracena_tracks=1)
    
print(cur_model.cpu())

cur_model = cur_model.to(device)

# Output directory for checkpoints and logs, named after the model hyperparameters

outdir = f"/beagle3/haky/users/saideep/projects/aracena_modeling/training_runs/ref_to_aracena_{model_type}_lr{lr}_se{sub_epochs}"
if not os.path.exists(outdir):
    os.makedirs(outdir)
    
# Logging

logfile = os.path.join(outdir, "log.txt")

if not os.path.exists(logfile):
    with open(logfile, "w") as f:
        f.write("\t".join([
                "Epoch","Group","SubEpoch","time","iter"
            ])+"\n")


# Load validation dataset

enformer_predictions_hdf5_path = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/"

aracena_hdf5_path = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/"

validation_loader = load_group_dataset(mapping_table, 
                                       "valid", 
                                       individual_cur, 
                                       os.path.join(enformer_predictions_hdf5_path,"ref-epigenome_valid_aracena.h5"), 
                                       os.path.join(aracena_hdf5_path,f"{individual_cur}_valid_aracena.h5"), 
                                       batch_size=256)
# Set up optimizer

optimizer = torch.optim.Adam(
        cur_model.parameters(),
        lr=lr)

criterion = torch.nn.PoissonNLLLoss(log_input=True, reduction="none")

# Tracking loss

run_iter = 0

start_time = time.perf_counter()

loss_tracker = []
valid_loss = {}

trainloader = None

for epoch in range(epochs):
    print("Epoch:",epoch)
    epoch_start = time.perf_counter()
    
    # The training set is split into "groups"; a new trainloader is created for each group
    # because the full training set does not fit into memory at once.
    for group in group_sampler(groups):
        
        # Delete the previous group's loader so its in-memory tensors can be freed
        del trainloader
        trainloader = load_group_dataset(mapping_table, group, individual_cur, 
                                        os.path.join(enformer_predictions_hdf5_path,f"ref-epigenome_{group}_aracena.h5"), 
                                        os.path.join(aracena_hdf5_path,f"{individual_cur}_{group}_aracena.h5"))
        
        # To reduce the overhead of loading a new group after every pass, we cycle through each
        # group multiple times before moving on to the next one. I call these cycles "sub-epochs";
        # each is akin to a single epoch from the earlier, smaller-scale runs.
        for sub_epoch in range(sub_epochs):
            
            for j, batch in enumerate(trainloader):
                if j%500 == 0:
                    print(f"Epoch:{epoch} Group:{group} SubEpoch:{sub_epoch} time:{time.perf_counter()-epoch_start} iter:{j}")
                    with open(logfile, "a") as f:
                        f.write("\t".join([str(x) for x in [
                            epoch,group,sub_epoch,(time.perf_counter()-epoch_start),j
                        ]])+"\n")
                # Skip batches that contain NaNs in the reference predictions
                if torch.sum(torch.isnan(batch["ref_targets"])) > 0:
                    # nan_count += 1
                    continue
                else:
                    # good_count += 1
                    pass

                optimizer.zero_grad()

                input = batch["ref_targets"]
                target = batch["aracena_targets"]

                # Conv1d expects (batch, channels, length), so move the track dimension
                # ahead of the bins before transferring to the device
                if model_type == "linear":
                    input = input.to(device)
                    target = target.to(device)
                elif model_type == "cnn":
                    input = input.swapaxes(1,2).to(device)
                    target = target.swapaxes(1,2).to(device)

                model_predictions = cur_model(input)

                # if j==0:
                #     create_tracks_for_plot(batch, input, target, model_predictions, model_type)
                #     plot_tracks(tracks_for_plot, interval_for_plot)

                loss = criterion(model_predictions, target)

                if torch.sum(torch.isnan(loss)) > 0:
                    print(input)
                    print(target)
                    print(loss)
                    print("epoch_",epoch, "_iter_",j)
                    continue
                mean_loss = float(loss.mean())
                if mean_loss < 0:
                    print("Negative loss", mean_loss)

                loss_tracker.append(mean_loss)
                run_iter += 1
                loss.mean().backward()

                optimizer.step()
                
            # Calculate validation loss after each sub_epoch (no gradient tracking needed here)
            cur_model.eval()
            with torch.no_grad():
                for j, batch in enumerate(validation_loader):

                    input = batch["ref_targets"]
                    target = batch["aracena_targets"]

                    if model_type == "linear":
                        input = input.to(device)
                        target = target.to(device)
                    elif model_type == "cnn":
                        input = input.swapaxes(1,2).to(device)
                        target = target.swapaxes(1,2).to(device)

                    model_predictions = cur_model(input)

                    loss = criterion(model_predictions, target)

                    valid_loss[run_iter] = float(loss.mean())
            cur_model.train()
                
    outfile = os.path.join(outdir, f"ref_to_aracena_{model_type}_lr{lr}_se{sub_epochs}_epoch{epoch}.pt")
    torch.save({
        'epoch': epoch,
        'model_state_dict': cur_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': torch.Tensor(loss_tracker)
    }, outfile)
RefToAracenaCNN(
  (model): Sequential(
    (0): Sequential(
      (0): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): GELU()
      (2): Conv1d(1, 24, kernel_size=(4,), stride=(2,), padding=(1,))
      (3): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): GELU()
      (5): Conv1d(24, 24, kernel_size=(4,), stride=(2,), padding=(1,))
      (6): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): GELU()
      (8): Conv1d(24, 24, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (1): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=2688, out_features=896, bias=True)
      (2): Softplus(beta=1, threshold=20)
      (3): Unflatten(dim=1, unflattened_size=(1, 896))
    )
  )
)
Epoch: 0
Epoch:0 Group:train5 SubEpoch:0 time:149.58251763135195 iter:0
Code
import matplotlib.pyplot as plt
import seaborn as sns
# plt.plot(loss_tracker)
sns.lineplot(loss_tracker)
<Axes: >
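
The per-sub-epoch validation losses stored in valid_loss were not plotted above. A short sketch for overlaying them on the training-loss curve, assuming the run has completed and both trackers are populated:

Code
# Hedged sketch: overlay the recorded validation losses (keyed by the training
# iteration at which they were computed) on the training-loss trajectory.
fig, ax = plt.subplots()
ax.plot(loss_tracker, alpha=0.5, label="training loss")
if valid_loss:
    iters, losses = zip(*sorted(valid_loss.items()))
    ax.plot(iters, losses, marker="o", label="validation loss")
ax.set_xlabel("iteration")
ax.set_ylabel("Poisson NLL")
ax.legend()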