retraining_pred_to_aracena_models

Author

Saideep Gona

Published

September 21, 2023

Training a predicted-epigenome-to-Aracena model

Load in some example input data

For an interactive GPU session on Beagle:

sinteractive --account=pi-haky --partition=beagle3 --gres=gpu:1 --mem=60GB --time=2:00:00
conda activate /beagle3/haky/users/temi/software/conda_envs/dl-tools
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/beagle3/haky/users/temi/software/conda_envs/dl-tools/lib
my_ip_address=$( /sbin/ip route get 8.8.8.8 | awk '{print $7;exit}' )
jupyter-notebook --no-browser --ip=$my_ip_address --port=15005

Code
import h5py
import numpy as np
import torch
import os,sys
import time
import copy

import kipoiseq

individual_cur = 'AF20'
group = 'train18'

aracena_hdf5_path = f"/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/{individual_cur}_{group}_aracena.h5"

enformer_pred_path = f"/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/ref-epigenome_{group}_aracena.h5"

# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     example_preds = f["ref_epigenome"][:,:,0:1]

We need a common mapping between the training examples in the two files, because some Aracena targets are missing.

Code
# sequences_bed = "/project2/haky/saideep/aracena_data_preprocessing/test_conversion/sequences_chunked.bed"

# train1_regions_ref = []
# with open(sequences_bed, "r") as f:
#     for line in f:
#         p_line = line.strip().split("\t")
#         if p_line[3] == "train1":
#             train1_regions_ref.append("_".join(p_line[0:3]))

# print(train1_regions_ref[0:10])

# with h5py.File(aracena_hdf5_path, "r") as f:
#     ara_regions = f["regions"][0,:,:]

# train1_regions_ara = []
# for x in range(ara_regions.shape[1]):
#     # print(ara_regions[:,x])
#     train1_regions_ara.append("_".join(["chr"+str(ara_regions[0,x]),str(ara_regions[1,x]),str(ara_regions[2,x])]))

# print(train1_regions_ara[0:10])


# index_mapping = {}

# ref_iter = 0
# for i in range(len(train1_regions_ara)):

#     while train1_regions_ara[i] != train1_regions_ref[ref_iter]:
#         ref_iter += 1
        
#     index_mapping[i] = ref_iter

# print(index_mapping)
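The commented-out cell above aligns the two region lists with a two-pointer scan. A minimal standalone sketch of the same idea, assuming both lists are in the same genomic order and the Aracena regions are an order-preserving subset of the reference regions (the function name `map_subset_indices` is hypothetical):

```python
def map_subset_indices(subset, reference):
    """Map each position in `subset` to the position of the same item in `reference`.

    Assumes `subset` is an order-preserving subsequence of `reference`,
    e.g. Aracena regions (some missing) versus reference-epigenome regions.
    """
    mapping = {}
    ref_iter = 0
    for i, region in enumerate(subset):
        # advance through the reference list until the region matches
        while ref_iter < len(reference) and reference[ref_iter] != region:
            ref_iter += 1
        if ref_iter == len(reference):
            raise ValueError(f"{region} not found (in order) in the reference list")
        mapping[i] = ref_iter
        ref_iter += 1  # each reference region is matched at most once
    return mapping


ref_regions = ["chr1_0_100", "chr1_100_200", "chr1_200_300", "chr2_0_100"]
ara_regions = ["chr1_0_100", "chr1_200_300", "chr2_0_100"]  # one region missing
print(map_subset_indices(ara_regions, ref_regions))  # {0: 0, 1: 2, 2: 3}
```

Unlike the original cell, this version stops with an error if a region is never found, instead of running off the end of the reference list.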

In a later blog post, I formalized the above into a table that maps each individual-region combination to its index in the corresponding HDF5 file. I also filtered out missing data, which otherwise disrupts training. Here I load that table:

Code
import pandas as pd

mapping_table = pd.read_csv("/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/index_files/remapped_table_filt.csv", index_col=0)
Code
mapping_table
18_928386_1059458 4_113630947_113762019 11_18427720_18558792 16_85805681_85936753 3_158386188_158517260 8_132158314_132289386 21_35639003_35770075 16_24521594_24652666 8_18647448_18778520 4_133441845_133572917 ... 14_30965924_31096996 11_29225411_29356483 14_82805352_82936424 14_44040470_44171542 15_25685343_25816415 19_33204702_33335774 14_41861379_41992451 19_30681544_30812616 14_61473198_61604270 2_129664471_129795543
EU47 0 1 2 3 4 6 7 8 9 10 ... 1899 1902 1903 1905 1906 1907 1908 1909 1910 1911
EU33 0 1 2 3 4 6 7 8 9 10 ... 1899 1902 1903 1905 1906 1907 1908 1909 1910 1911
AF28 0 1 2 3 4 6 7 8 9 10 ... 1899 1902 1903 1905 1906 1907 1908 1909 1910 1911
EU05 0 1 2 3 4 6 7 8 9 10 ... 1899 1902 1903 1905 1906 1907 1908 1909 1910 1911
AF20 0 1 2 3 4 6 7 8 9 10 ... 1919 1922 1923 1925 1926 1928 1929 1930 1931 1932
AF30 0 1 2 3 4 6 7 8 9 10 ... 1899 1902 1903 1905 1906 1907 1908 1909 1910 1911
AF16 0 1 2 3 4 6 7 8 9 10 ... 1899 1902 1903 1905 1906 1907 1908 1909 1910 1911
ref-epigenome 0 1 2 3 4 5 6 7 8 9 ... 1909 1912 1913 1915 1916 1918 1919 1920 1921 1922
group train1 train1 train1 train1 train1 train1 train1 train1 train1 train1 ... test test test test test test test test test test

9 rows × 26280 columns
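Each column of this table is a region, each cell is that region's index in the corresponding HDF5 file, and the last row records the data split. A small sketch of how such a table can be subset by group and queried, using a toy table whose values are copied from the first two and last columns shown above (`toy_table` is a hypothetical name):

```python
import pandas as pd

# Toy stand-in for mapping_table. Mixing indices with the "group" label gives an
# object dtype when read from CSV, so the indices are stored as strings here too.
toy_table = pd.DataFrame(
    {
        "18_928386_1059458": ["0", "0", "train1"],
        "4_113630947_113762019": ["1", "1", "train1"],
        "2_129664471_129795543": ["1932", "1922", "test"],
    },
    index=["AF20", "ref-epigenome", "group"],
)

# keep only the regions belonging to one data partition
train_table = toy_table.loc[:, toy_table.loc["group"] == "train1"]

# look up where a region lives in the individual's file and in the reference file
region = train_table.columns[1]
ara_idx = int(train_table.loc["AF20", region])
ref_idx = int(train_table.loc["ref-epigenome", region])
print(region, ara_idx, ref_idx)  # 4_113630947_113762019 1 1
```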

Load the model. Note that a fully connected network directly relating all Basenji targets (5313 tracks × 896 bins) to all Aracena targets (12 tracks × 896 bins) would need roughly 51 billion parameters. Adding a 100-node hidden layer reduces this to roughly 480 million parameters.
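These counts follow from the target shapes. Assuming 5313 Basenji/Enformer tracks and 12 Aracena tracks, each over 896 bins (the shapes in the RefToAracenaCNN printout in this post), the arithmetic is:

```python
n_bins = 896
basenji_tracks = 5313   # input channels
aracena_tracks = 12     # output channels

n_in = basenji_tracks * n_bins    # 4,760,448 input values
n_out = aracena_tracks * n_bins   # 10,752 output values

# weights + biases of a single dense layer mapping every input to every output
direct = n_in * n_out + n_out

# one hidden layer of 100 nodes: input -> hidden -> output
hidden = 100
with_hidden = (n_in * hidden + hidden) + (hidden * n_out + n_out)

print(f"{direct:,}")       # 51,184,347,648
print(f"{with_hidden:,}")  # 477,130,852
```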

UPDATE: I no longer plan to use fully connected layers. A linear model can instead be built as a simple linear transformation in which channels are connected independently.
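One way to read this (an assumption, since the `RefToAracenaMLP` source is not shown here): apply `nn.Linear` to the track dimension of a `[batch, 896, tracks]` tensor. PyTorch's `nn.Linear` transforms only the last dimension, so the same weights are applied to every bin independently and the parameter count drops to tracks_in × tracks_out + tracks_out:

```python
import torch
from torch import nn

torch.manual_seed(0)

n_bins, tracks_in, tracks_out = 896, 5313, 12

# nn.Linear acts on the last dimension only, so each of the 896 bins is
# mapped from 5313 tracks to 12 tracks with one shared weight matrix
per_bin_linear = nn.Linear(tracks_in, tracks_out)

x = torch.rand(2, n_bins, tracks_in)   # [batch, bins, tracks]
y = per_bin_linear(x)                  # [batch, bins, 12]

n_params = sum(p.numel() for p in per_bin_linear.parameters())
print(y.shape, n_params)  # torch.Size([2, 896, 12]) 63768
```

Compared with the 51 billion parameters of the fully connected version, this is tiny, but it cannot use information from neighboring bins; that is one motivation for trying a CNN.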

Also, a basic CNN seems to train better than a flat linear layer.

Code
import os,sys, importlib
sys.path.append("/beagle3/haky/users/saideep/github_repos/Daily-Blog-Sai/posts/2023-09-21-retraining_pred_to_aracena_models")


import ref_to_aracena_models
importlib.reload(ref_to_aracena_models)
import torch
from torch import nn
from torch.utils.data import DataLoader


device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(device)


model_type = "cnn"

if model_type == "linear":
    cur_model = ref_to_aracena_models.RefToAracenaMLP(hidden_layer_dims=[])
elif model_type == "cnn":
    cur_model = ref_to_aracena_models.RefToAracenaCNN()
cur_model = cur_model.to(device)
cuda
Code

mem_params = sum([param.nelement()*param.element_size() for param in cur_model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in cur_model.buffers()])
mem = mem_params + mem_bufs
print("Estimated model size (GB):",mem*(1e-9))

num_params = sum([param.nelement() for param in cur_model.parameters()])
print("Estimated model params:",num_params)
Estimated model size (GB): 0.117793224
Estimated model params: 29437578
Code
cur_model.cpu()
# torch.from_numpy(example_preds[:,:,0]).to(device).unsqueeze(0).swapaxes(1,2).shape
RefToAracenaCNN(
  (model): Sequential(
    (0): Sequential(
      (0): BatchNorm1d(5313, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): GELU()
      (2): Conv1d(5313, 24, kernel_size=(4,), stride=(2,), padding=(1,))
      (3): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): GELU()
      (5): Conv1d(24, 24, kernel_size=(4,), stride=(2,), padding=(1,))
      (6): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): GELU()
      (8): Conv1d(24, 24, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (1): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=2688, out_features=10752, bias=True)
      (2): Softplus(beta=1, threshold=20)
      (3): Unflatten(dim=1, unflattened_size=(12, 896))
    )
  )
)
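For reference, the printed architecture can be rebuilt directly in PyTorch. This is a reconstruction from the printout only, not the actual `ref_to_aracena_models` source (the helper name `build_cnn_from_printout` is hypothetical). Each `Conv1d` with kernel 4, stride 2 and padding 1 halves the length (896 → 448 → 224 → 112), which is where `in_features = 24 × 112 = 2688` comes from, and the totals reproduce the estimates printed earlier (29,437,578 parameters; 117,793,224 bytes including BatchNorm buffers):

```python
import torch
from torch import nn

def build_cnn_from_printout(in_tracks=5313, out_tracks=12, n_bins=896, channels=24):
    conv_kwargs = dict(kernel_size=4, stride=2, padding=1)
    features = nn.Sequential(
        nn.BatchNorm1d(in_tracks),
        nn.GELU(),
        nn.Conv1d(in_tracks, channels, **conv_kwargs),
        nn.BatchNorm1d(channels),
        nn.GELU(),
        nn.Conv1d(channels, channels, **conv_kwargs),
        nn.BatchNorm1d(channels),
        nn.GELU(),
        nn.Conv1d(channels, channels, **conv_kwargs),
    )
    # three stride-2 convolutions shrink 896 bins to 896 / 8 = 112
    flat_dim = channels * (n_bins // 8)
    head = nn.Sequential(
        nn.Flatten(),
        nn.Linear(flat_dim, out_tracks * n_bins),
        nn.Softplus(),  # keeps predictions non-negative
        nn.Unflatten(1, (out_tracks, n_bins)),
    )
    return nn.Sequential(features, head)

model = build_cnn_from_printout()
n_params = sum(p.numel() for p in model.parameters())
# parameters plus BatchNorm buffers (running stats and the int64 batch counter)
n_bytes = sum(t.numel() * t.element_size() for t in [*model.parameters(), *model.buffers()])

model.eval()  # use running BatchNorm statistics for a single example
with torch.no_grad():
    out = model(torch.rand(1, 5313, 896))

print(n_params, n_bytes, tuple(out.shape))  # 29437578 117793224 (1, 12, 896)
```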
Code
!nvidia-smi
Sat Sep 30 13:05:57 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 470.82.01    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A40          On   | 00000000:17:00.0 Off |                    0 |
|  0%   35C    P0    74W / 300W |    795MiB / 45634MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A   3928254      C   ...s/dl-tools/bin/python3.11      793MiB |
+-----------------------------------------------------------------------------+
Code
# import time

# cur_model = cur_model.to(device)

# times = []
# for rep in range(10):
#     start = time.time()
#     ex_out = cur_model(torch.from_numpy(example_preds[:,:,0]).to(device).unsqueeze(0).swapaxes(1,2).to(device))
#     end = time.time()
#     times.append(end-start)
# print(ex_out.shape)
# print("Average forward pass time (s):",np.mean(times))

In an earlier run of the timing cell above (now commented out), a forward pass took only about 0.15 seconds of CPU time, much faster than Enformer.
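A caveat when timing forward passes on a GPU: CUDA kernels launch asynchronously, so wall-clock timing without synchronization can under-report the cost. A minimal timing helper, sketched for any `nn.Module` and input tensor (the name `time_forward` is hypothetical), warms up first, disables autograd, and synchronizes when running on CUDA:

```python
import time
import torch
from torch import nn

def time_forward(model, x, n_reps=10, n_warmup=2):
    """Return the mean wall-clock seconds per forward pass."""
    model.eval()
    use_cuda = x.is_cuda
    with torch.no_grad():
        for _ in range(n_warmup):         # exclude one-off setup costs
            model(x)
        if use_cuda:
            torch.cuda.synchronize()
        times = []
        for _ in range(n_reps):
            start = time.perf_counter()
            model(x)
            if use_cuda:
                torch.cuda.synchronize()  # wait for queued kernels to finish
            times.append(time.perf_counter() - start)
    return sum(times) / len(times)

toy = nn.Sequential(nn.Conv1d(4, 8, kernel_size=4, stride=2, padding=1), nn.GELU())
print(f"{time_forward(toy, torch.rand(1, 4, 896)):.6f} s")
```

For the model in this post, the equivalent call would be `time_forward(cur_model, x.to(device))` with `x` shaped like one batch of inputs.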

We can now try training the model.

Code
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd
Code
class LazyHDF5DatasetRefAracena(Dataset):
    """Reads one example at a time from disk: low memory use, but slow."""

    def __init__(self, enformer_ref_targets, aracena_targets, aracena_ref_index_mapping):
        super().__init__()

        # dict: Aracena example index -> reference-epigenome example index
        self.index_mapping = aracena_ref_index_mapping
        self.files = {}
        self.files['ref_targets'] = h5py.File(enformer_ref_targets, 'r')
        self.files['aracena_targets'] = h5py.File(aracena_targets, 'r')

    def __getitem__(self, index):

        dt = {}
        # h5py slices are NumPy arrays, which have no .float(); convert to tensors first
        dt['ref_targets'] = torch.from_numpy(self.files['ref_targets']['ref_epigenome'][:, :, self.index_mapping[index]].swapaxes(0,1)).float()
        dt['aracena_targets'] = torch.from_numpy(self.files['aracena_targets']['targets'][:, :, index].swapaxes(0,1)).float()

        return dt

    def __len__(self):
        return len(self.index_mapping)

class HDF5DatasetRefAracena(Dataset):
    """Loads all reference-epigenome and Aracena targets into memory up front.

    mapping_table: rows are individuals plus "ref-epigenome" (and "group"), columns are
    regions, and each cell is that region's index in the corresponding HDF5 file.
    """

    def __init__(self, enformer_ref_targets, aracena_targets, mapping_table, individual, summarization="mean", bin_size=128):
        super().__init__()

        if summarization not in ("mean", "sum"):
            raise ValueError(f"summarization must be 'mean' or 'sum', got {summarization!r}")

        self.bin_size = bin_size
        self.summarization = summarization
        self.individual = individual
        self.mapping_table = mapping_table
        self.files = {}
        self.regions = {}

        with h5py.File(enformer_ref_targets, 'r') as ert:
            self.files['ref_targets'] = torch.Tensor(ert['ref_epigenome'][:,:,:].swapaxes(0,1)).float()
            self.regions['ref_targets'] = ert['regions'][:,:,:]
        with h5py.File(aracena_targets, 'r') as at:
            self.files['aracena_targets'] = torch.Tensor(at['targets'][:,:,:].swapaxes(0,1)).float()
            self.regions['aracena_targets'] = at['regions'][:,:,:]

        if summarization == "sum":
            # convert per-bin means into per-bin totals (in place, to avoid a second large copy)
            self.files['ref_targets'].mul_(bin_size)
            self.files['aracena_targets'].mul_(bin_size)

    def __getitem__(self, index):

        region = self.mapping_table.columns[index]
        # mapping_table has object dtype (it includes the "group" row), so cast indices to int
        ref_idx = int(self.mapping_table.loc["ref-epigenome", region])
        ara_idx = int(self.mapping_table.loc[self.individual, region])

        dt = {}
        # swapaxes here undoes the swapaxes applied in __init__
        dt['ref_targets'] = self.files['ref_targets'][:, :, ref_idx].swapaxes(0,1)
        dt['aracena_targets'] = self.files['aracena_targets'][:, :, ara_idx].swapaxes(0,1)
        dt['ref_region'] = self.regions['ref_targets'][0, :, ref_idx]
        dt['ara_region'] = self.regions['aracena_targets'][0, :, ara_idx]

        return dt

    def __len__(self):
        return len(self.mapping_table.columns)
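To sanity-check the indexing logic without the real HDF5 files, the same pattern can be exercised on small in-memory arrays. This sketch assumes the on-disk layout implied by the training code in this post, arrays of shape `[bins, tracks, examples]`, and uses made-up sizes; `ToyRefAracenaDataset` is a hypothetical name:

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

class ToyRefAracenaDataset(Dataset):
    """In-memory stand-in: arrays are [bins, tracks, examples]; the table maps regions to example indices."""

    def __init__(self, ref_array, ara_array, mapping_table, individual):
        self.ref = torch.from_numpy(ref_array).float()
        self.ara = torch.from_numpy(ara_array).float()
        self.mapping_table = mapping_table
        self.individual = individual

    def __len__(self):
        return len(self.mapping_table.columns)

    def __getitem__(self, index):
        region = self.mapping_table.columns[index]
        ref_idx = int(self.mapping_table.loc["ref-epigenome", region])
        ara_idx = int(self.mapping_table.loc[self.individual, region])
        return {"ref_targets": self.ref[:, :, ref_idx], "aracena_targets": self.ara[:, :, ara_idx]}

n_bins, ref_tracks, ara_tracks = 896, 3, 2
# the reference file holds regions 0, 1, 2; the individual's file is missing region 1,
# and every value in an example equals its region number so mismatches are easy to spot
ref_array = np.stack([np.full((n_bins, ref_tracks), i, dtype=np.float32) for i in range(3)], axis=-1)
ara_array = np.stack([np.full((n_bins, ara_tracks), i, dtype=np.float32) for i in (0, 2)], axis=-1)
table = pd.DataFrame({"r0": [0, 0], "r2": [2, 1]}, index=["ref-epigenome", "AF20"])

loader = DataLoader(ToyRefAracenaDataset(ref_array, ara_array, table, "AF20"), batch_size=1, shuffle=False)
for batch in loader:
    # after batching: [1, bins, tracks]
    print(batch["ref_targets"].shape, batch["ref_targets"][0, 0, 0].item(), batch["aracena_targets"][0, 0, 0].item())
```

Each batch should pair a reference region with the individual's copy of the same region, so the two printed values match (0.0 and 0.0, then 2.0 and 2.0).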
    
    

Test training setup using a single data partition

Code
mapping_table_subgroup = mapping_table.loc[:, mapping_table.loc["group"] == group]

dataset = HDF5DatasetRefAracena(enformer_pred_path, aracena_hdf5_path, mapping_table_subgroup, individual_cur, summarization="mean")

trainloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    drop_last=True,
    # num_workers=DISTENV['world_size'],
    # collate_fn=concat_pairs,
    pin_memory=torch.cuda.is_available(),
)

print(len(trainloader))
# for j,batch in enumerate(trainloader):
#     print(j)
#     print(batch['ref_targets'].shape)
#     print(batch['aracena_targets'].shape)
17
Code
# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     print(f["ref_epigenome"][0:10,:,:].shape)

# import h5py
# import sys
# enformer_pred_path = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/train1_ref.h5"

# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     x = f["ref_epigenome"][:,:,:]

# print(sys.getsizeof(x)*1e-9)
    
Code
def plot_tracks(tracks, interval, height=1.5):
  import matplotlib.pyplot as plt
  import seaborn as sns

  fig, axes = plt.subplots(len(tracks), 1, figsize=(20, height * len(tracks)), sharex=True)
  for ax, (title, y) in zip(axes, tracks.items()):
    ax.fill_between(np.linspace(interval[1], interval[2], num=len(y)), y)
    ax.set_title(title)
    sns.despine(top=True, right=True, bottom=True)
  ax.set_xlabel("_".join([str(x) for x in interval]))
  plt.tight_layout()

Loading 2000 training examples into memory at once seems to require at least 40 GB. If I do not preload everything, data loading is very slow (2+ seconds per target set).

Therefore, use:

sinteractive --account=pi-haky --mem=60GB --partition=bigmem
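The 40 GB figure is consistent with a back-of-the-envelope estimate, assuming the targets are held in memory as float32 (which `torch.Tensor(...)` produces) with 5313 reference tracks and 12 Aracena tracks over 896 bins:

```python
n_examples = 2000
n_bins = 896
bytes_per_value = 4  # float32

ref_bytes = n_examples * n_bins * 5313 * bytes_per_value
ara_bytes = n_examples * n_bins * 12 * bytes_per_value

print(f"reference: {ref_bytes / 1e9:.1f} GB, aracena: {ara_bytes / 1e9:.2f} GB")
# reference: 38.1 GB, aracena: 0.09 GB
```

Almost all of the memory goes to the reference epigenome. The temporary NumPy array created while reading from HDF5 may push peak usage above this estimate, which would be one reason to request more than 40 GB.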

Moving on to the main training loop

Code
epochs = 100
lr = 0.001
model_type = "cnn"

if model_type == "linear":
    cur_model = ref_to_aracena_models.RefToAracenaMLP(basenji_tracks=1, num_aracena_tracks=1, hidden_layer_dims=[])
elif model_type == "cnn":
    cur_model = ref_to_aracena_models.RefToAracenaCNN(basenji_tracks=1, num_aracena_tracks=1)
    
print(cur_model.cpu())

cur_model = cur_model.to(device)

optimizer = torch.optim.Adam(
        cur_model.parameters(),
        lr=lr)

criterion = torch.nn.PoissonNLLLoss(log_input=True, reduction="none")

run_iter = 0

loss_tracker = []
loss_all = []

nan_count = 0
good_count = 0

for epoch in range(epochs):
    print("Epoch:",epoch)
    epoch_start = time.perf_counter()
    for j, batch in enumerate(trainloader):
        if j%500 == 0:
            print(j)
        # print(j)
        if torch.sum(torch.isnan(batch["ref_targets"])) > 0:
            nan_count += 1
            continue
        else:
            good_count += 1

        
        optimizer.zero_grad()
        
        # Functions for plotting inputs/outputs as needed
        
        # print("ref_regions:", batch['ref_region'])
        # print("ara_regions:", batch['ara_region'])
        region_numpy = batch["ref_region"][0].numpy()
        resized_int = kipoiseq.Interval(*["chr"+str(region_numpy[0]), region_numpy[1], region_numpy[2]]).resize(896*128)
        interval_for_plot = [f"Epoch:{epoch} "+resized_int.chrom,resized_int.start,resized_int.end]


        input = batch["ref_targets"]
        # print(input.shape)
        # print(input)

        target = batch["aracena_targets"].to(device) 
        # print(target.shape)
        # print(target)


        if model_type == "linear":
            input = input.to(device)
            target = target.to(device)
        elif model_type == "cnn":
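            # single-track test: reference track 4766 as input, Aracena track 0 as target, each shaped [batch, 1, 896]
            input = input.swapaxes(1,2)[:,4766,:].reshape((-1,1,896)).to(device)
            target = target.swapaxes(1,2)[:,0,:].reshape((-1,1,896)).to(device)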
        # print(input.shape)
        model_predictions = cur_model(input)
        

        if j==0:
            print(resized_int)
            print("input shape",input.shape)
            print("target shape",target.shape)
            print("model_predictions shape",model_predictions.shape)
            if model_type == "linear":
                tracks_for_plot = {
                    "ref": input[0,:,0].cpu().numpy(), 
                    "aracena": target[0,:,0].cpu().numpy(),
                    "pred": model_predictions[0,:,0].cpu().detach().numpy()
                }
            elif model_type == "cnn":
                tracks_for_plot = {
                    "ref": input[0,0,:].cpu().numpy(), 
                    "aracena": target[0,0,:].cpu().numpy(),
                    "pred": model_predictions[0,0,:].cpu().detach().numpy()
                }
                
            
            plot_tracks(tracks_for_plot, interval_for_plot)
            


        loss = criterion(model_predictions, target)


        if torch.sum(torch.isnan(loss)) > 0:
            print(input)
            print(target)
            print(loss)
            print("epoch_",epoch, "_iter_",j)
            continue
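        mean_loss = float(loss.mean())
        # with log_input=True and full=False, PoissonNLLLoss computes exp(pred) - target*pred
        # (the log(target!) term is omitted), so negative values are possible and not by themselves an error
        if mean_loss < 0:
            print("Negative loss",resized_int)
        loss_all.append(loss.cpu().detach().numpy())
        loss_tracker.append(mean_loss)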
        

        loss.mean().backward()
        # torch.nn.utils.clip_grad_norm_(cur_model.parameters(), 1)
        optimizer.step()
RefToAracenaCNN(
  (model): Sequential(
    (0): Sequential(
      (0): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): GELU()
      (2): Conv1d(1, 24, kernel_size=(4,), stride=(2,), padding=(1,))
      (3): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): GELU()
      (5): Conv1d(24, 24, kernel_size=(4,), stride=(2,), padding=(1,))
      (6): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): GELU()
      (8): Conv1d(24, 24, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (1): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=2688, out_features=896, bias=True)
      (2): Softplus(beta=1, threshold=20)
      (3): Unflatten(dim=1, unflattened_size=(1, 896))
    )
  )
)
Epoch: 0
0
chr12:28860106-28974794:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 1
0
chr1:63870304-63984992:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 2
0
chr1:63870304-63984992:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 3
0
chr2:34718828-34833516:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 4
0
chr8:32497029-32611717:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 5
0
chr1:41041852-41156540:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 6
0
chr8:32497029-32611717:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 7
0
chr13:104599207-104713895:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 8
0
chr1:185738464-185853152:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 9
0
chr4:10446291-10560979:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 10
0
chr1:63870304-63984992:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 11
0
chr1:185738464-185853152:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 12
0
chr4:189012390-189127078:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 13
0
chr1:185738464-185853152:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 14
0
chr13:104599207-104713895:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 15
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 16
0
chr1:185738464-185853152:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 17
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 18
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 19
0
chr21:40173526-40288214:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 20
0
chr2:34718828-34833516:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 21
0
chr4:10446291-10560979:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 22
0
chr4:87530737-87645425:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 23
0
chr4:189012390-189127078:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 24
0
chr4:87530737-87645425:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 25
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 26
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 27
0
chr2:34718828-34833516:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 28
0
chr12:28860106-28974794:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 29
0
chr4:133843631-133958319:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 30
0
chr1:41041852-41156540:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 31
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 32
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 33
0
chr1:63870304-63984992:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 34
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 35
0
chr19:15542632-15657320:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 36
0
chr8:32497029-32611717:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 37
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 38
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 39
0
chr1:41041852-41156540:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 40
0
chr2:34718828-34833516:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 41
0
chr1:41041852-41156540:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 42
0
chr1:185738464-185853152:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 43
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 44
0
chr22:31219262-31333950:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 45
0
chr21:40173526-40288214:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 46
0
chr12:28860106-28974794:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 47
0
chr21:40173526-40288214:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 48
0
chr1:41041852-41156540:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 49
0
chr4:189012390-189127078:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 50
0
chr4:87530737-87645425:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 51
0
chr12:28860106-28974794:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 52
0
chr4:189012390-189127078:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 53
0
chr4:10446291-10560979:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 54
0
chr2:34718828-34833516:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 55
0
chr1:41041852-41156540:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 56
0
chr21:40173526-40288214:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 57
0
chr22:31219262-31333950:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 58
0
chr4:133843631-133958319:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 59
0
chr4:133843631-133958319:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 60
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 61
0
chr19:15542632-15657320:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 62
0
chr22:31219262-31333950:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 63
0
chr12:28860106-28974794:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 64
0
chr8:32497029-32611717:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 65
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 66
0
chr1:185738464-185853152:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 67
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 68
0
chr1:41041852-41156540:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 69
0
chr21:40173526-40288214:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 70
0
chr8:32497029-32611717:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 71
0
chr22:31219262-31333950:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 72
0
chr4:189012390-189127078:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 73
0
chr17:56791717-56906405:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 74
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 75
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 76
0
chr2:34718828-34833516:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 77
0
chr4:10446291-10560979:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 78
0
chr19:15542632-15657320:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 79
0
chr4:133843631-133958319:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 80
0
chr4:87530737-87645425:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 81
0
chr19:15542632-15657320:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 82
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 83
0
chr1:41041852-41156540:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 84
0
chr1:185738464-185853152:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 85
0
chr1:63870304-63984992:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 86
0
chr15:48009570-48124258:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 87
0
chr8:32497029-32611717:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 88
0
chr4:189012390-189127078:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 89
0
chr4:189012390-189127078:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 90
0
chr4:87530737-87645425:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 91
0
chr2:34718828-34833516:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 92
0
chr1:185738464-185853152:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 93
0
chr1:210600485-210715173:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 94
0
chr4:10446291-10560979:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 95
0
chr4:10446291-10560979:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 96
0
chr12:28860106-28974794:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 97
0
chr19:15542632-15657320:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 98
0
chr4:133843631-133958319:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])
Epoch: 99
0
chr4:189012390-189127078:.
input shape torch.Size([1, 1, 896])
target shape torch.Size([1, 1, 896])
model_predictions shape torch.Size([1, 1, 896])

Code
# Save a checkpoint: last epoch index, model and optimizer state, and the per-batch training losses
outfile = f"/beagle3/haky/users/saideep/projects/aracena_modeling/saved_models/{group}_ref_to_aracena_{model_type}_poissonlog_ep{epoch}_lr{lr}.pt"
torch.save({
    'epoch': epoch,
    'model_state_dict': cur_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': torch.Tensor(loss_tracker)
}, outfile)
print(outfile)
/beagle3/haky/users/saideep/projects/aracena_modeling/saved_models/train18_ref_to_aracena_cnn_poissonlog_ep99_lr0.01.pt
Code
# if not torch.sum(torch.isnan(batch["ref_targets"])):
#     print("yay")

# print(loss_tracker[1:10])
# print(nan_count, good_count)
print(j)
print(len(loss_tracker))
print(input.shape)
print(nan_count, good_count)
print(len(trainloader))
mapping_table_subgroup.shape
16
1700
torch.Size([1, 1, 896])
0 1700
17
(9, 17)

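It seems that close to a third of the reference predictions contain NaN values, which is a problem for training, so some pre-filtering is needed to handle them. (For the run shown above, the printed counts are 0 NaN and 1700 valid training steps.)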
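A minimal sketch of such pre-filtering, assuming the predictions and targets can be read as NumPy arrays whose leading axis indexes examples (slicing an h5py dataset returns one; the real `ref_epigenome` layout may put examples on a different axis, in which case the axis should be adjusted). The helper name `valid_example_indices` is hypothetical. It returns the indices of examples with no NaN in either input or target, which could then be used to subset the mapping table before building the dataset.

```python
import numpy as np


def valid_example_indices(preds, targets):
    """Return indices of examples with no NaN in either preds or targets.

    Hypothetical helper: assumes both arrays share the leading example axis,
    e.g. shape (n_examples, n_tracks, n_bins).
    """
    preds = np.asarray(preds)
    targets = np.asarray(targets)
    if preds.shape[0] != targets.shape[0]:
        raise ValueError("preds and targets must have the same number of examples")
    # Collapse every axis except the first: True if the example has any NaN
    nan_in_preds = np.isnan(preds).reshape(preds.shape[0], -1).any(axis=1)
    nan_in_targets = np.isnan(targets).reshape(targets.shape[0], -1).any(axis=1)
    keep = ~(nan_in_preds | nan_in_targets)
    return np.flatnonzero(keep)


# Toy example: 4 examples, 1 track, 896 bins, with NaNs injected into two of them
rng = np.random.default_rng(0)
preds = rng.random((4, 1, 896))
targets = rng.random((4, 1, 896))
preds[1, 0, 10] = np.nan
targets[3, 0, 0] = np.nan
keep = valid_example_indices(preds, targets)
print(keep)  # [0 2]
```

Filtering once up front, rather than skipping NaN batches inside the training loop, keeps every epoch the same size and avoids wasted forward passes.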

Code
import matplotlib.pyplot as plt
import seaborn as sns
# plt.plot(loss_tracker)
sns.lineplot(loss_tracker)

# plt.ylim(0,4)
<Axes: >

Code
sns.lineplot(loss_tracker[-100:])
<Axes: >

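After testing different configurations, it seems that Poisson loss with log input is required for stable training curves. This makes sense given the nature of the tracks: they are non-negative, count-like coverage signals, which suits a Poisson likelihood, and with log input the model outputs the log of the rate, so the rate exp(output) is always positive.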
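For reference, PyTorch's `torch.nn.PoissonNLLLoss(log_input=True)` treats the model output x as the log of the Poisson rate and, with the default `full=False`, computes exp(x) - y * x per element, averaged under `reduction='mean'`. The NumPy sketch below (function names are hypothetical) reproduces that formula to illustrate why the log parameterisation is well behaved: the rate exp(x) is positive for any real output, and the per-element gradient exp(x) - y vanishes exactly when the predicted rate equals the target.

```python
import numpy as np


def poisson_nll_log_input(log_rate, target):
    """Mean Poisson negative log-likelihood with log-space inputs.

    Same formula as torch.nn.PoissonNLLLoss(log_input=True, full=False,
    reduction='mean'): exp(x) - y * x. The log(y!) term is dropped because
    it does not depend on the model output.
    """
    log_rate = np.asarray(log_rate, dtype=float)
    target = np.asarray(target, dtype=float)
    return float(np.mean(np.exp(log_rate) - target * log_rate))


def poisson_nll_grad(log_rate, target):
    """Elementwise gradient of exp(x) - y * x with respect to x (before averaging)."""
    return np.exp(np.asarray(log_rate, dtype=float)) - np.asarray(target, dtype=float)


target = np.array([1.0, 2.0, 5.0])
best = np.log(target)  # the loss is minimised elementwise at x = log(y)
print(poisson_nll_log_input(best, target))
print(poisson_nll_grad(best, target))  # approximately zero everywhere
```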

Training with the full dataset and batching with DDP

Having tested a basic training setup and observed successful training, I am now ready to scale up. This involves a few notable changes:

1.) To span the full training set, data must be loaded into memory in chunks, so within each epoch a new chunk has to be loaded whenever the previous one is exhausted. Given the loading time, it may be worth organizing training into “sub-epochs”, where multiple epochs are run on a given chunk before switching to the next one.

2.) Use DDP (PyTorch DistributedDataParallel) to parallelize batched training across the 4 GPUs of a single node. This should speed up training while making efficient use of the node's resources.

3.) Add validation steps during training.
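The three changes above can be sketched as a training schedule, written in plain Python so the order of operations can be inspected without GPUs. All names here are hypothetical, and the per-rank split only mimics the strided sharding that `torch.utils.data.DistributedSampler(drop_last=True)` performs; in a real DDP run that sampler (with `set_epoch` for per-epoch shuffling) would do this instead.

```python
def chunk_bounds(n_examples, chunk_size):
    """Split [0, n_examples) into contiguous (start, stop) chunks."""
    return [(start, min(start + chunk_size, n_examples))
            for start in range(0, n_examples, chunk_size)]


def shard_for_rank(indices, rank, world_size):
    """Strided split of indices across ranks, dropping the uneven tail.

    Same idea as DistributedSampler with drop_last=True: every rank
    receives the same number of examples.
    """
    usable = (len(indices) // world_size) * world_size
    return list(indices[rank:usable:world_size])


def training_plan(n_examples, chunk_size, n_epochs, sub_epochs_per_chunk,
                  rank=0, world_size=1, validate_every=1):
    """Return the ordered list of load / train / validate steps for one rank."""
    plan = []
    for epoch in range(n_epochs):
        for chunk_id, (start, stop) in enumerate(chunk_bounds(n_examples, chunk_size)):
            # Each chunk is loaded into memory once per epoch ...
            plan.append(("load", epoch, chunk_id, start, stop))
            shard = shard_for_rank(range(start, stop), rank, world_size)
            # ... then reused for several passes ("sub-epochs") before switching
            for sub_epoch in range(sub_epochs_per_chunk):
                plan.append(("train", epoch, chunk_id, sub_epoch, shard))
        if (epoch + 1) % validate_every == 0:
            plan.append(("validate", epoch))
    return plan


# Toy usage: 10 examples, chunks of 4, 2 sub-epochs per chunk, rank 1 of 2 GPUs
plan = training_plan(n_examples=10, chunk_size=4, n_epochs=1,
                     sub_epochs_per_chunk=2, rank=1, world_size=2)
for step in plan:
    print(step)
```

The trade-off is that more sub-epochs per chunk amortize the loading cost but risk overfitting to one chunk before the model sees the rest of the data.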

Code
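import torch
from torch.utils.data import DataLoader

def load_single_dataset(enformer_pred_path, aracena_hdf5_path, full_mapping_table, group, batch_size=64):
    # Pair reference predictions with Aracena targets via the mapping table
    # (HDF5DatasetRefAracena is defined earlier in this notebook; `group` is currently unused)
    dataset = HDF5DatasetRefAracena(enformer_pred_path, aracena_hdf5_path, full_mapping_table)

    trainloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,  # use the requested batch size
        shuffle=True,
        drop_last=True,
        # num_workers=DISTENV['world_size'],
        # collate_fn=concat_pairs,
        pin_memory=torch.cuda.is_available(),  # faster host-to-GPU copies when a GPU is present
    )

    return trainloader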

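To illustrate what one chunk of the chunked loading could look like, here is a minimal sketch that reads a contiguous range of examples into memory and yields shuffled mini-batches, mirroring the `batch_size`, `shuffle`, and `drop_last` arguments passed to `DataLoader` above. It assumes the inputs and targets are already row-aligned (the real dataset aligns them through the mapping table), and `iter_batches_from_chunk` is a hypothetical name. Slicing an h5py `Dataset` with `[start:stop]` reads that range into a NumPy array, so the same code should work on an open HDF5 file or on in-memory arrays.

```python
import numpy as np


def iter_batches_from_chunk(inputs, targets, start, stop, batch_size=64,
                            shuffle=True, drop_last=True, seed=None):
    """Load examples [start, stop) into memory and yield (x, y) mini-batches.

    `inputs` and `targets` may be NumPy arrays or h5py Datasets with the
    example axis first; they are assumed to be row-aligned.
    """
    x_chunk = np.asarray(inputs[start:stop])
    y_chunk = np.asarray(targets[start:stop])
    order = np.arange(x_chunk.shape[0])
    if shuffle:
        np.random.default_rng(seed).shuffle(order)
    for begin in range(0, len(order), batch_size):
        batch_idx = order[begin:begin + batch_size]
        if drop_last and len(batch_idx) < batch_size:
            break  # like DataLoader(drop_last=True): skip the short final batch
        yield x_chunk[batch_idx], y_chunk[batch_idx]


# Toy usage with the (1, 896) example shape seen in the training log
x = np.zeros((150, 1, 896))
y = np.zeros((150, 1, 896))
for xb, yb in iter_batches_from_chunk(x, y, 0, 150, batch_size=64, seed=0):
    print(xb.shape, yb.shape)  # (64, 1, 896) twice; the last 22 examples are dropped
```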