ref_to_aracena_mlp

Author

Saideep Gona

Published

September 5, 2023

Training a predicted-epigenome-to-Aracena model

Load in some example input data

Code
import h5py
import numpy as np
import torch
import os,sys
import time

individual_cur = 'AF20'
group = 'train18'

aracena_hdf5_path = f"/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/{individual_cur}_{group}_aracena.h5"

enformer_pred_path = f"/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/ref-epigenome_{group}_aracena.h5"

# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     example_preds = f["ref_epigenome"][:,:,0:1]

We need to find a common mapping of training examples because some Aracena targets are missing.

Code
# sequences_bed = "/project2/haky/saideep/aracena_data_preprocessing/test_conversion/sequences_chunked.bed"

# train1_regions_ref = []
# with open(sequences_bed, "r") as f:
#     for line in f:
#         p_line = line.strip().split("\t")
#         if p_line[3] == "train1":
#             train1_regions_ref.append("_".join(p_line[0:3]))

# print(train1_regions_ref[0:10])

# with h5py.File(aracena_hdf5_path, "r") as f:
#     ara_regions = f["regions"][0,:,:]

# train1_regions_ara = []
# for x in range(ara_regions.shape[1]):
#     # print(ara_regions[:,x])
#     train1_regions_ara.append("_".join(["chr"+str(ara_regions[0,x]),str(ara_regions[1,x]),str(ara_regions[2,x])]))

# print(train1_regions_ara[0:10])


# index_mapping = {}

# ref_iter = 0
# for i in range(len(train1_regions_ara)):

#     while train1_regions_ara[i] != train1_regions_ref[ref_iter]:
#         ref_iter += 1
        
#     index_mapping[i] = ref_iter

# print(index_mapping)

In a subsequent blog post, I formalized the above into a table which maps specific individual-region combinations to corresponding indices in the HDF5 files. I also filtered for missing data, which clogs the training process. Here I will load the table:

Code
import pandas as pd

mapping_table = pd.read_csv("/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/index_files/remapped_table_filt.csv", index_col=0)
Code
mapping_table
18_928386_1059458 4_113630947_113762019 11_18427720_18558792 16_85805681_85936753 3_158386188_158517260 8_132158314_132289386 16_24521594_24652666 8_18647448_18778520 4_133441845_133572917 17_74167260_74298332 ... 11_46314072_46445144 19_19876046_20007118 11_37827086_37958158 11_27046320_27177392 3_183471775_183602847 3_184159909_184290981 11_29225411_29356483 19_33204702_33335774 19_30681544_30812616 2_129664471_129795543
EU05 0 1 2 3 4 6 8 9 10 11 ... 1887 1888 1891 1894 1896 1897 1902 1907 1909 1911
AF28 0 1 2 3 4 6 8 9 10 11 ... 1887 1888 1891 1894 1896 1897 1902 1907 1909 1911
AF30 0 1 2 3 4 6 8 9 10 11 ... 1887 1888 1891 1894 1896 1897 1902 1907 1909 1911
EU33 0 1 2 3 4 6 8 9 10 11 ... 1887 1888 1891 1894 1896 1897 1902 1907 1909 1911
AF16 0 1 2 3 4 6 8 9 10 11 ... 1887 1888 1891 1894 1896 1897 1902 1907 1909 1911
EU47 0 1 2 3 4 6 8 9 10 11 ... 1887 1888 1891 1894 1896 1897 1902 1907 1909 1911
AF20 0 1 2 3 4 6 8 9 10 11 ... 1887 1888 1891 1894 1896 1897 1902 1907 1909 1911
ref-epigenome 0 1 2 3 4 5 6 7 8 9 ... 1077 1078 1079 1081 1083 1084 1087 1090 1091 1092
group train1 train1 train1 train1 train1 train1 train1 train1 train1 train1 ... test test test test test test test test test test

9 rows × 22698 columns

Load in the MLP model. Note that a fully connected network directly relating the flattened Basenji/Enformer targets to the Aracena targets would use roughly 500 billion parameters; adding a 100-node hidden layer reduces this to about 500 million parameters.

UPDATE: I am no longer planning to use fully connected layers. A linear layer can instead be applied as a simple linear transformation over the track channels at each bin, with the bins handled independently (the same weights shared across positions).

Also, a basic CNN seems to train better than a flat linear layer.
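
To make the per-bin linear idea concrete, here is a minimal sketch (the 896 bins, 5313 Enformer tracks, and 12 Aracena tracks are assumptions based on the shapes printed later in this post): applying nn.Linear to the last dimension shares one 5313 -> 12 map across all bins, which keeps the parameter count tiny compared to a fully connected mapping over flattened targets.

Code
import torch
from torch import nn

n_bins, n_enformer_tracks, n_aracena_tracks = 896, 5313, 12

# One linear map over the track (channel) dimension, shared across all bins
per_bin_linear = nn.Linear(n_enformer_tracks, n_aracena_tracks)
print("Per-bin linear parameters:",
      sum(p.numel() for p in per_bin_linear.parameters()))  # ~64k

# nn.Linear acts on the last dimension, so a (batch, bins, tracks) tensor
# is transformed bin-by-bin with the same weights
x = torch.randn(1, n_bins, n_enformer_tracks)
print(per_bin_linear(x).shape)  # torch.Size([1, 896, 12])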

Code
import os,sys, importlib
sys.path.append("/beagle3/haky/users/saideep/github_repos/Daily-Blog-Sai/posts/2023-09-05-ref_to_aracena_mlp")


import ref_to_aracena_models
importlib.reload(ref_to_aracena_models)
import torch
from torch import nn
from torch.utils.data import DataLoader


device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(device)

cur_model = ref_to_aracena_models.RefToAracenaMLP(hidden_layer_dims=[])
cpu
Code
sys.path.append("/beagle3/haky/users/saideep/github_repos/shared_pipelines/enformer_pipeline_pytorch/scripts/modules")
import enformer_pytorch

cur_model = enformer_pytorch.Enformer()
Code
mem_params = sum([param.nelement()*param.element_size() for param in cur_model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in cur_model.buffers()])
mem = mem_params + mem_bufs
print("Estimated model size (GB):",mem*(1e-9))



nb_acts = 0
def count_output_act(m, input, output):
    global nb_acts
    nb_acts += output.nelement()

for module in cur_model.modules():
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.BatchNorm2d):
        module.register_forward_hook(count_output_act)


nb_params = sum([p.nelement() for p in cur_model.parameters()])
print("Number of parameters:",nb_params)
Estimated model size (GB): 1.00497948
Number of parameters: 251218220
Code
cur_model.cpu()
# torch.from_numpy(example_preds[:,:,0]).to(device).unsqueeze(0).swapaxes(1,2).shape
RefToAracenaMLP(
  (model): Sequential(
    (0): Linear(in_features=5313, out_features=12, bias=True)
    (1): Softplus(beta=1, threshold=20)
    (2): Dropout(p=0.1, inplace=False)
  )
)
Code
!nvidia-smi
Tue Sep 19 15:11:35 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 470.82.01    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-PCI...  On   | 00000000:B1:00.0 Off |                    0 |
| N/A   27C    P0    34W / 250W |      3MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
Code
# import time

# cur_model = cur_model.to(device)

# times = []
# for rep in range(10):
#     start = time.time()
#     ex_out = cur_model(torch.from_numpy(example_preds[:,:,0]).to(device).unsqueeze(0).swapaxes(1,2).to(device))
#     end = time.time()
#     times.append(end-start)
# print(ex_out.shape)
# print("Average forward pass time (s):",np.mean(times))

Based on the timing runs above (now commented out), a forward pass requires only about 0.15 seconds of CPU time, much faster than Enformer.
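
As a rough check of that number, here is a small timing sketch with synthetic input (it assumes the (1, 896, 5313) input shape used in the training loop below, since example_preds is commented out above):

Code
import time
import torch

timing_model = ref_to_aracena_models.RefToAracenaMLP(hidden_layer_dims=[]).cpu()
dummy_input = torch.randn(1, 896, 5313)  # stand-in for one reference prediction

times = []
for _ in range(10):
    start = time.perf_counter()
    with torch.no_grad():
        out = timing_model(dummy_input)
    times.append(time.perf_counter() - start)

print(out.shape)
print("Average forward pass time (s):", sum(times) / len(times))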

We can now try to do some training of the model

Code
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd
Code
class LazyHDF5DatasetRefAracena(Dataset):

    def __init__(self, enformer_ref_targets, aracena_targets, aracena_ref_index_mapping):
        super().__init__()

        self.index_mapping = aracena_ref_index_mapping
        self.files = {}
        self.files['ref_targets'] = h5py.File(enformer_ref_targets, 'r')
        self.files['aracena_targets'] = h5py.File(aracena_targets, 'r')
    
    def __getitem__(self, index):

        dt = {}
        dt['ref_targets'] = torch.from_numpy(self.files['ref_targets']['ref_epigenome'][:, :, self.index_mapping[index]].swapaxes(0, 1)).float()
        dt['aracena_targets'] = torch.from_numpy(self.files['aracena_targets']['targets'][:, :, index].swapaxes(0, 1)).float()

        return dt

    def __len__(self):
        return len(self.index_mapping.keys())

class HDF5DatasetRefAracena(Dataset):

    def __init__(self, enformer_ref_targets, aracena_targets, mapping_table, individual):
        super().__init__()

        self.individual = individual
        self.mapping_table = mapping_table
        self.files = {}
        with h5py.File(enformer_ref_targets, 'r') as ert:
            self.files['ref_targets'] = torch.Tensor(ert['ref_epigenome'][:,:,:].swapaxes(0,1)).float()
        with h5py.File(aracena_targets, 'r') as at:
            self.files['aracena_targets'] = torch.Tensor(at['targets'][:,:,:].swapaxes(0,1)).float()
    
    def __getitem__(self, index):

        region = self.mapping_table.columns[index]

        dt = {}
        dt['ref_targets'] = self.files['ref_targets'][: , :, int(self.mapping_table.loc["ref-epigenome", region])].swapaxes(0,1)
        dt['aracena_targets'] = self.files['aracena_targets'][:, :,  int(self.mapping_table.loc[self.individual, region])].swapaxes(0,1)

        return dt

    def __len__(self):
        return len(self.mapping_table.columns)
    
    

Testing the training framework with just a single data partition

Code
mapping_table_subgroup = mapping_table.loc[:, mapping_table.loc["group"] == group]

dataset = HDF5DatasetRefAracena(enformer_pred_path, aracena_hdf5_path, mapping_table_subgroup, individual_cur)

trainloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    drop_last=True,
    # num_workers=DISTENV['world_size'],
    # collate_fn=concat_pairs,
    pin_memory=torch.cuda.is_available(),
)

# for j,batch in enumerate(trainloader):
#     print(j)
#     print(batch['ref_targets'].shape)
#     print(batch['aracena_targets'].shape)
Code
# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     print(f["ref_epigenome"][0:10,:,:].shape)

# import h5py
# import sys
# enformer_pred_path = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/train1_ref.h5"

# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     x = f["ref_epigenome"][:,:,:]

# print(sys.getsizeof(x)*1e-9)
    

It seems that loading 2000 training examples into memory at once requires a minimum of around 40GB. If I do not load everything up front, the data load times are very high (2+ seconds per target set).

Therefore, use: sinteractive --account=pi-haky --mem=60GB --partition=bigmem
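
The 40GB figure follows from a back-of-the-envelope calculation, assuming float32 targets of shape 896 bins x 5313 tracks per example:

Code
# ~2000 examples x 896 bins x 5313 tracks x 4 bytes (float32)
n_examples, n_bins, n_tracks = 2000, 896, 5313
bytes_needed = n_examples * n_bins * n_tracks * 4
print("Approximate memory (GB):", bytes_needed * 1e-9)  # ~38 GB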

Moving on to the main training loop

Code
epochs = 500
lr = 0.000001

cur_model = ref_to_aracena_models.RefToAracenaMLP(hidden_layer_dims=[])
cur_model = cur_model.to(device)

optimizer = torch.optim.Adam(
        cur_model.parameters(),
        lr=lr)

criterion = torch.nn.PoissonNLLLoss(log_input=True,reduction="none")

run_iter = 0

loss_tracker = []

nan_count = 0
good_count = 0

for epoch in range(epochs):
    print(epoch)
    epoch_start = time.perf_counter()
    for j, batch in enumerate(trainloader):
        # print(j)
        if torch.sum(torch.isnan(batch["ref_targets"])) > 0:
            nan_count += 1
            continue
        else:
            good_count += 1

        
        optimizer.zero_grad()

        input = batch["ref_targets"].to(device)
        # print(input.shape)
        # print(input)

        target = batch["aracena_targets"].to(device)
        # print(target.shape)
        # print(target)


        model_predictions = cur_model(input)

        loss = criterion(model_predictions, target)
        # print(loss.shape)
        # print(loss)
        
        # print(torch.sum(torch.isnan(loss)))


        if torch.sum(torch.isnan(loss)) > 0:
            print(input)
            print(target)
            print(loss)
            print("epoch_",epoch, "_iter_",j)
            continue
        mean_loss = float(loss.mean())
        loss_tracker.append(mean_loss)

        loss.mean().backward()
        optimizer.step()
0
1
2
...
499
Code
outfile = f"/beagle3/haky/users/saideep/projects/aracena_modeling/saved_models/{group}_ref_to_aracena_linear_poissonlog_ep{epoch}_lr{lr}.pt"
torch.save({
    'epoch': epoch,
    'model_state_dict': cur_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': torch.Tensor(loss_tracker)
}, outfile)
Code
# if not torch.sum(torch.isnan(batch["ref_targets"])):
#     print("yay")

# print(loss_tracker[1:10])
# print(nan_count, good_count)
print(j)
print(len(loss_tracker))
print(input.shape)
print(nan_count, good_count)
print(len(trainloader))
mapping_table_subgroup.shape
12
130
torch.Size([1, 896, 5313])
0 130
13
(9, 13)

It seems that close to a third of the reference predictions contain NaN, which is not great. I need to do some pre-filtering to handle this.
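
As a first pass at that pre-filtering, something along these lines could work (find_nan_free_indices is a hypothetical helper, not yet part of the pipeline): scan the reference predictions once, record the NaN-free region indices, and subset the mapping table accordingly.

Code
import numpy as np

def find_nan_free_indices(h5_path, dataset_name="ref_epigenome"):
    # Record the indices of regions whose predictions contain no NaNs
    keep = []
    with h5py.File(h5_path, "r") as f:
        n_regions = f[dataset_name].shape[2]
        for idx in range(n_regions):
            if not np.isnan(f[dataset_name][:, :, idx]).any():
                keep.append(idx)
    return keep

# e.g. drop mapping-table columns whose reference index contains NaNs
# good_idx = set(find_nan_free_indices(enformer_pred_path))
# keep_cols = mapping_table_subgroup.loc["ref-epigenome"].astype(int).isin(good_idx)
# mapping_table_filt = mapping_table_subgroup.loc[:, keep_cols]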

Code
import matplotlib.pyplot as plt
import seaborn as sns
# plt.plot(loss_tracker)
sns.lineplot(loss_tracker[-20:])
# plt.ylim(0,4)
<Axes: >

After testing different configurations, it seems that Poisson loss with log input is required for stable training curves. This makes sense given the count-like nature of the tracks.
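
For reference, with log_input=True PyTorch's PoissonNLLLoss interprets the raw model output as a log-rate, so the loss stays well-behaved even for negative or large linear outputs. A quick sanity check of the formula it uses (exp(pred) - target * pred, with the default full=False):

Code
import torch

pred = torch.tensor([[0.5, -1.0, 2.0]])   # raw model outputs treated as log-rates
target = torch.tensor([[1.0, 0.0, 5.0]])  # observed track values

loss = torch.nn.PoissonNLLLoss(log_input=True, reduction="none")(pred, target)
manual = torch.exp(pred) - target * pred
print(torch.allclose(loss, manual))  # True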

Training with full dataset and batching with DDP

Having tested a basic training setup and observed successful training, I am now ready to scale up the training. This involves a few notable changes:

1.) To span the full training set, data must be loaded into memory in chunks at different points during training, so each epoch requires reloading data whenever a new chunk is needed. Given the loading time, it may be worth organizing training into “sub-epochs”, where multiple passes are run over a given chunk before it is swapped out (see the sketch at the end of this post).

2.) Utilizing DDP to parallelize training across 4 GPUs on a single node with batching. This can speed up training while efficiently using a single node’s resources.

3.) Validation steps also need to be added during training.

Code
def load_single_dataset(enformer_pred_path, aracena_hdf5_path, mapping_table, individual):

    dataset = HDF5DatasetRefAracena(enformer_pred_path, aracena_hdf5_path, mapping_table, individual)

    trainloader = DataLoader(
        dataset=dataset,
        batch_size=1,
        shuffle=True,
        drop_last=True,
        # num_workers=DISTENV['world_size'],
        # collate_fn=concat_pairs,
        pin_memory=torch.cuda.is_available(),
    )

    return trainloader

# dataset = HDF5DatasetRefAracena(enformer_pred_path, aracena_hdf5_path, mapping_table)

# trainloader = DataLoader(
#     dataset=dataset,
#     batch_size=1,
#     shuffle=True,
#     drop_last=True,
#     # num_workers=DISTENV['world_size'],
#     # collate_fn=concat_pairs,
#     pin_memory=torch.cuda.is_available(),
# )
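
Finally, here is a minimal sketch of the sub-epoch chunking idea from point 1 above (chunk_paths, a list of per-group HDF5 file pairs, is an assumption about how the chunks will be organized; DDP wrapping and validation are left out):

Code
def train_chunked(model, optimizer, criterion, chunk_paths, mapping_table,
                  individual, sub_epochs=5, device="cpu"):
    # chunk_paths: list of (ref_h5, aracena_h5, group_name) tuples, one per chunk
    for ref_h5, aracena_h5, group_name in chunk_paths:
        # Pay the loading cost once, then run several passes over this chunk
        subgroup = mapping_table.loc[:, mapping_table.loc["group"] == group_name]
        loader = load_single_dataset(ref_h5, aracena_h5, subgroup, individual)

        for _ in range(sub_epochs):
            for batch in loader:
                if torch.isnan(batch["ref_targets"]).any():
                    continue
                optimizer.zero_grad()
                preds = model(batch["ref_targets"].to(device))
                loss = criterion(preds, batch["aracena_targets"].to(device)).mean()
                loss.backward()
                optimizer.step()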