import h5py
import numpy as np
import torch
import os, sys
import time

individual_cur = 'AF20'
group = 'train18'

aracena_hdf5_path = f"/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/{individual_cur}_{group}_aracena.h5"
enformer_pred_path = f"/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/ref-epigenome_{group}_aracena.h5"

# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     example_preds = f["ref_epigenome"][:,:,0:1]
I need to find a common mapping of training examples, because some Aracena targets are missing.
Code
# sequences_bed = "/project2/haky/saideep/aracena_data_preprocessing/test_conversion/sequences_chunked.bed"

# train1_regions_ref = []
# with open(sequences_bed, "r") as f:
#     for line in f:
#         p_line = line.strip().split("\t")
#         if p_line[3] == "train1":
#             train1_regions_ref.append("_".join(p_line[0:3]))

# print(train1_regions_ref[0:10])

# with h5py.File(aracena_hdf5_path, "r") as f:
#     ara_regions = f["regions"][0,:,:]
#     train1_regions_ara = []
#     for x in range(ara_regions.shape[1]):
#         # print(ara_regions[:,x])
#         train1_regions_ara.append("_".join(["chr"+str(ara_regions[0,x]), str(ara_regions[1,x]), str(ara_regions[2,x])]))

# print(train1_regions_ara[0:10])

# index_mapping = {}
# ref_iter = 0
# for i in range(len(train1_regions_ara)):
#     while train1_regions_ara[i] != train1_regions_ref[ref_iter]:
#         ref_iter += 1
#     index_mapping[i] = ref_iter

# print(index_mapping)
In a subsequent blog post, I formalized the above into a table which maps specific individual-region combinations to their corresponding indices in the HDF5 files. I also filtered out missing data, which clogs the training process. Here I will load that table:
Code
import pandas as pd

mapping_table = pd.read_csv("/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/index_files/remapped_table_filt.csv", index_col=0)
Code
mapping_table
                        EU05  AF28  AF30  EU33  AF16  EU47  AF20  ref-epigenome   group
18_928386_1059458          0     0     0     0     0     0     0              0  train1
4_113630947_113762019      1     1     1     1     1     1     1              1  train1
11_18427720_18558792       2     2     2     2     2     2     2              2  train1
16_85805681_85936753       3     3     3     3     3     3     3              3  train1
3_158386188_158517260      4     4     4     4     4     4     4              4  train1
8_132158314_132289386      6     6     6     6     6     6     6              5  train1
16_24521594_24652666       8     8     8     8     8     8     8              6  train1
8_18647448_18778520        9     9     9     9     9     9     9              7  train1
4_133441845_133572917     10    10    10    10    10    10    10              8  train1
17_74167260_74298332      11    11    11    11    11    11    11              9  train1
...                      ...   ...   ...   ...   ...   ...   ...            ...     ...
11_46314072_46445144    1887  1887  1887  1887  1887  1887  1887           1077    test
19_19876046_20007118    1888  1888  1888  1888  1888  1888  1888           1078    test
11_37827086_37958158    1891  1891  1891  1891  1891  1891  1891           1079    test
11_27046320_27177392    1894  1894  1894  1894  1894  1894  1894           1081    test
3_183471775_183602847   1896  1896  1896  1896  1896  1896  1896           1083    test
3_184159909_184290981   1897  1897  1897  1897  1897  1897  1897           1084    test
11_29225411_29356483    1902  1902  1902  1902  1902  1902  1902           1087    test
19_33204702_33335774    1907  1907  1907  1907  1907  1907  1907           1090    test
19_30681544_30812616    1909  1909  1909  1909  1909  1909  1909           1091    test
2_129664471_129795543   1911  1911  1911  1911  1911  1911  1911           1092    test

(shown transposed for readability; the table itself has 9 rows × 22698 columns)
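To make the indexing concrete, here is a minimal sketch of how this table can be used to pull a matched Enformer-prediction / Aracena-target pair for one region of the current group. The "targets" dataset name in the Aracena HDF5 is a placeholder (only "regions" and "ref_epigenome" appear above), so treat the HDF5 keys as assumptions.

# Regions belonging to the current group (e.g. "train18")
group_cols = mapping_table.loc[:, mapping_table.loc["group"] == group]
region = group_cols.columns[0]

# Per-file indices for this region
ara_idx = int(mapping_table.loc[individual_cur, region])
ref_idx = int(mapping_table.loc["ref-epigenome", region])

with h5py.File(aracena_hdf5_path, "r") as ara_f, h5py.File(enformer_pred_path, "r") as ref_f:
    enformer_pred = ref_f["ref_epigenome"][:, :, ref_idx]   # reference predictions for this region
    aracena_target = ara_f["targets"][:, :, ara_idx]        # "targets" is a placeholder dataset name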
Load in the MLP model. Note that a fully connected network directly relating the Basenji targets to the Aracena targets would require roughly 500 billion parameters. Adding a 100-node hidden layer reduces this to roughly 500 million parameters.
UPDATE: I am no longer planning to use fully connected layers. Linear layers can instead be constructed as simple linear transformations with channels connected independently, rather than fully connecting the flattened inputs and outputs.
Also, a basic CNN seems to train better than a flat linear layer; a sketch of both options follows below.
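For illustration, here is a minimal sketch of the two alternatives, assuming the 896-bin × 5313-track Enformer reference predictions as input. The number of Aracena output tracks and the CNN hyperparameters are placeholders, not the actual model configuration.

import torch
import torch.nn as nn

N_BINS, N_ENF_TRACKS = 896, 5313   # Enformer reference-prediction shape
N_ARA_TRACKS = 8                   # placeholder: the real number of Aracena tracks may differ

# Option 1: position-wise linear map (equivalent to a 1x1 convolution).
# Each of the 896 bins is transformed independently; only the 5313 tracks are mixed.
linear_model = nn.Linear(N_ENF_TRACKS, N_ARA_TRACKS)   # applied to (batch, 896, 5313)

# Option 2: a basic 1D CNN over the bin axis, which also mixes neighboring bins.
cnn_model = nn.Sequential(
    nn.Conv1d(N_ENF_TRACKS, 512, kernel_size=5, padding=2),
    nn.ReLU(),
    nn.Conv1d(512, N_ARA_TRACKS, kernel_size=1),
)   # applied to (batch, 5313, 896)

for name, m in [("position-wise linear", linear_model), ("basic CNN", cnn_model)]:
    n_params = sum(p.nelement() for p in m.parameters())
    print(f"{name}: {n_params:,} parameters")

Both options stay in the tens-of-millions range at most, far below the parameter counts of the fully connected designs discussed above.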
mem_params = sum([param.nelement() * param.element_size() for param in cur_model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in cur_model.buffers()])
mem = mem_params + mem_bufs
print("Estimated model size (GB):", mem * (1e-9))

nb_acts = 0
def count_output_act(m, input, output):
    global nb_acts
    nb_acts += output.nelement()

for module in cur_model.modules():
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.BatchNorm2d):
        module.register_forward_hook(count_output_act)

nb_params = sum([p.nelement() for p in cur_model.parameters()])
print("Number of parameters:", nb_params)
Estimated model size (GB): 1.00497948
Number of parameters: 251218220
# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     print(f["ref_epigenome"][0:10,:,:].shape)

# import h5py
# import sys

# enformer_pred_path = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/train1_ref.h5"

# with h5py.File(enformer_pred_path, "r") as f:
#     print(list(f.keys()))
#     print(f["ref_epigenome"].shape)
#     x = f["ref_epigenome"][:,:,:]
#     print(sys.getsizeof(x)*1e-9)
It seems that a minimum of roughly 40 GB of memory is required to hold 2,000 training examples at a time. If I do not load everything into memory, the data load times are very high (2+ seconds per target set).
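As a rough sanity check on that figure, here is a back-of-the-envelope estimate assuming the reference predictions are stored as float32 with the 896-bin × 5313-track shape:

# Memory needed to hold 2,000 examples of the Enformer reference predictions in float32
n_examples, n_bins, n_tracks, bytes_per_float = 2000, 896, 5313, 4
total_gb = n_examples * n_bins * n_tracks * bytes_per_float / 1e9
print(f"~{total_gb:.1f} GB")   # ~38 GB, before the Aracena targets are counted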
# if not torch.sum(torch.isnan(batch["ref_targets"])):
#     print("yay")
# print(loss_tracker[1:10])
# print(nan_count, good_count)

print(j)
print(len(loss_tracker))
print(input.shape)
print(nan_count, good_count)
print(len(trainloader))
mapping_table_subgroup.shape
12
130
torch.Size([1, 896, 5313])
0 130
13
(9, 13)
It seems that close to a third of the reference predictions contain NaNs, which is not great. I need to do some pre-filtering to handle this.
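A minimal sketch of the kind of pre-filtering I have in mind, assuming a chunk of reference predictions and Aracena targets has already been loaded as tensors (ref_preds and ara_targets are placeholder names):

# Keep only examples whose reference predictions and Aracena targets are NaN-free.
# ref_preds: (n_examples, 896, 5313); ara_targets: (n_examples, 896, n_tracks)  -- placeholders
valid_mask = ~(torch.isnan(ref_preds).flatten(1).any(dim=1) |
               torch.isnan(ara_targets).flatten(1).any(dim=1))
ref_preds_filt = ref_preds[valid_mask]
ara_targets_filt = ara_targets[valid_mask]
print(f"Kept {valid_mask.sum().item()} / {valid_mask.numel()} examples")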
Code
import matplotlib.pyplot as plt
import seaborn as sns

# plt.plot(loss_tracker)
sns.lineplot(loss_tracker[-20:])
# plt.ylim(0,4)
<Axes: >
After testing different configurations, it seems that Poisson loss with log input is required for stable training curves. This makes sense given the nature of the tracks.
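Concretely, this corresponds to PyTorch's PoissonNLLLoss with log_input=True, so the model outputs are treated as log-rates. A minimal sketch with placeholder shapes (the 8 output tracks are an assumption, not the actual track count):

import torch
import torch.nn as nn

criterion = nn.PoissonNLLLoss(log_input=True)   # model outputs interpreted as log-rates

log_rates = torch.randn(1, 896, 8)        # stand-in for model(batch_input)
targets = torch.rand(1, 896, 8) * 10      # non-negative, count-like track values
loss = criterion(log_rates, targets)
print(loss.item())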
Training with full dataset and batching with DDP
Having tested a basic training setup and observed that it trains successfully, I am now ready to scale up. This involves a few notable changes:
1.) To span the full training set, data must be loaded into memory in chunks. This means that within each epoch, data must be reloaded each time a new chunk is needed. Given the loading time, it may be worth splitting the training into "sub-epochs", where multiple passes are run over a given chunk before it is swapped out.
2.) Utilizing DDP to parallelize the training across 4 GPUs on a single node with batching. This can speed up training while efficiently utilizing a single node's resources (see the sketch after this list).
3.) I also need to add validation steps during training.
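Below is a minimal sketch of what this could look like with PyTorch DDP, launched with torchrun --nproc_per_node=4. Here cur_model, load_chunk, and the epoch/chunk counts are placeholders for the actual model and data-loading code, not the final implementation.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler

def main():
    # torchrun sets RANK / LOCAL_RANK / WORLD_SIZE for each of the 4 processes
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    n_epochs, n_chunks, n_sub_epochs = 10, 12, 3   # placeholder schedule

    model = DDP(cur_model.to(local_rank), device_ids=[local_rank])   # cur_model defined elsewhere
    criterion = nn.PoissonNLLLoss(log_input=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(n_epochs):
        for chunk_idx in range(n_chunks):
            # load_chunk is a placeholder that reads one in-memory chunk from the HDF5 files
            ref_preds, ara_targets = load_chunk(chunk_idx)
            dataset = TensorDataset(ref_preds, ara_targets)
            sampler = DistributedSampler(dataset)           # shards the chunk across the 4 GPUs
            loader = DataLoader(dataset, batch_size=4, sampler=sampler)

            # "sub-epochs": several passes over the current chunk before swapping it out
            for sub_epoch in range(n_sub_epochs):
                sampler.set_epoch(epoch * n_sub_epochs + sub_epoch)
                for inputs, targets in loader:
                    inputs, targets = inputs.to(local_rank), targets.to(local_rank)
                    optimizer.zero_grad()
                    loss = criterion(model(inputs), targets)
                    loss.backward()
                    optimizer.step()

        # validation steps would go here (e.g. on rank 0 only), using a held-out chunk

    dist.destroy_process_group()

if __name__ == "__main__":
    main()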