Having demonstrated that the Enformer-predicted epigenome is a good enough predictor of the Aracena epigenome, we can now continue the training process, this time using the entire training corpus and training for longer (more epochs).
import pandas as pd

mapping_table = pd.read_csv("/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/index_files/remapped_table_filt.csv", index_col=0)
Define functions for metric reporting during training
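The metric-reporting functions themselves are not reproduced in this section. As a minimal sketch of what such a helper might look like, assuming batches expose predicted and observed track tensors as in the training loop below (the function name and the Pearson-correlation summary are placeholders of my own, not the notebook's actual code):

import torch

def report_batch_metrics(predictions, targets, eps=1e-8):
    """Summary metrics for one batch: mean Poisson NLL and Pearson correlation."""
    # Poisson negative log-likelihood, matching the training criterion used below
    nll = torch.nn.functional.poisson_nll_loss(predictions, targets, log_input=True)

    # Pearson correlation between flattened predicted and observed tracks
    p = predictions.flatten().float()
    t = targets.flatten().float()
    p = p - p.mean()
    t = t - t.mean()
    r = (p * t).sum() / (p.norm() * t.norm() + eps)

    return {"poisson_nll": float(nll), "pearson_r": float(r)}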
Define other utility functions for training
Code
import random

def group_sampler(groups):
    """Return a shuffled copy of the group list so training visits groups in a random order each epoch."""
    group_shuffled = groups[:]
    random.shuffle(group_shuffled)
    return group_shuffled
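The training loop below also relies on a load_group_dataset helper that is defined in an earlier part of the notebook and not shown here. As a rough, illustrative sketch of the behavior it is assumed to have (pairing the Enformer reference-epigenome predictions with the matched Aracena targets for one group and wrapping them in a DataLoader) — the HDF5 layout of one dataset per region, the "group" column in the mapping table, and the dictionary-style batches are all assumptions:

import h5py
import torch
from torch.utils.data import Dataset, DataLoader

class PairedEpigenomeDataset(Dataset):
    """Pairs Enformer reference-epigenome predictions with Aracena targets for one group."""

    def __init__(self, mapping_table, group, individual, ref_h5_path, aracena_h5_path):
        # individual is unused here; the per-individual file is already selected by aracena_h5_path
        self.regions = mapping_table.index[mapping_table["group"] == group]
        self.ref_h5 = h5py.File(ref_h5_path, "r")
        self.aracena_h5 = h5py.File(aracena_h5_path, "r")

    def __len__(self):
        return len(self.regions)

    def __getitem__(self, idx):
        region = str(self.regions[idx])
        return {
            "ref_targets": torch.from_numpy(self.ref_h5[region][:]).float(),
            "aracena_targets": torch.from_numpy(self.aracena_h5[region][:]).float(),
        }

def load_group_dataset(mapping_table, group, individual, ref_h5_path, aracena_h5_path, batch_size=64):
    dataset = PairedEpigenomeDataset(mapping_table, group, individual, ref_h5_path, aracena_h5_path)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)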
Main training loop
Code
import os
import time

import torch

# ref_to_aracena_models and load_group_dataset are assumed to be defined/imported in earlier cells

epochs = 2
sub_epochs = 3
lr = 0.001
model_type = "cnn"
individual_cur = "AF20"
groups = ["train" + str(x) for x in range(1, 19)]

# Load model
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

if model_type == "linear":
    cur_model = ref_to_aracena_models.RefToAracenaMLP(basenji_tracks=1, num_aracena_tracks=1, hidden_layer_dims=[])
elif model_type == "cnn":
    cur_model = ref_to_aracena_models.RefToAracenaCNN(basenji_tracks=1, num_aracena_tracks=1)

print(cur_model.cpu())
cur_model = cur_model.to(device)

# Directory for saving model parameters
outdir = f"/beagle3/haky/users/saideep/projects/aracena_modeling/training_runs/ref_to_aracena_{model_type}_lr{lr}_se{sub_epochs}"
if not os.path.exists(outdir):
    os.makedirs(outdir)

# Logging
logfile = os.path.join(outdir, "log.txt")
if not os.path.exists(logfile):
    with open(logfile, "w") as f:
        f.write("\t".join(["Epoch", "Group", "SubEpoch", "time", "iter"]) + "\n")

# Load validation dataset
enformer_predictions_hdf5_path = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/"
aracena_hdf5_path = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training/"

validation_loader = load_group_dataset(
    mapping_table,
    "valid",
    individual_cur,
    os.path.join(enformer_predictions_hdf5_path, "ref-epigenome_valid_aracena.h5"),
    os.path.join(aracena_hdf5_path, f"{individual_cur}_valid_aracena.h5"),
    batch_size=256,
)

# Set up optimizer and loss
optimizer = torch.optim.Adam(cur_model.parameters(), lr=lr)
criterion = torch.nn.PoissonNLLLoss(log_input=True, reduction="none")

# Loss tracking
run_iter = 0
start_time = time.perf_counter()
loss_tracker = []
valid_loss = {}
trainloader = None

for epoch in range(epochs):
    print("Epoch:", epoch)
    epoch_start = time.perf_counter()

    # The training set is split into "groups"; a new trainloader is created for each group
    # because not all groups fit into memory at once.
    for group in group_sampler(groups):
        del trainloader
        trainloader = load_group_dataset(
            mapping_table,
            group,
            individual_cur,
            os.path.join(enformer_predictions_hdf5_path, f"ref-epigenome_{group}_aracena.h5"),
            os.path.join(aracena_hdf5_path, f"{individual_cur}_{group}_aracena.h5"),
        )

        # To reduce the overhead of loading one group after another, we cycle through each
        # group multiple times before moving on to the next. I call these cycles "sub-epochs";
        # they are akin to the single epochs used before.
        for sub_epoch in range(sub_epochs):
            for j, batch in enumerate(trainloader):
                if j % 500 == 0:
                    print(f"Epoch:{epoch} Group:{group} SubEpoch:{sub_epoch} time:{time.perf_counter()-epoch_start} iter:{j}")
                    with open(logfile, "a") as f:
                        f.write("\t".join([str(x) for x in [epoch, group, sub_epoch, (time.perf_counter() - epoch_start), j]]) + "\n")

                # Skip batches with NaNs in the reference targets
                if torch.sum(torch.isnan(batch["ref_targets"])) > 0:
                    continue

                optimizer.zero_grad()

                inputs = batch["ref_targets"]
                target = batch["aracena_targets"]

                if model_type == "linear":
                    inputs = inputs.to(device)
                    target = target.to(device)
                elif model_type == "cnn":
                    # The CNN expects (batch, tracks, positions), so swap the last two axes
                    inputs = inputs.swapaxes(1, 2).to(device)
                    target = target.swapaxes(1, 2).to(device)

                model_predictions = cur_model(inputs)

                # if j == 0:
                #     create_tracks_for_plot(batch, inputs, target, model_predictions, model_type)
                #     plot_tracks(tracks_for_plot, interval_for_plot)

                loss = criterion(model_predictions, target)
                if torch.sum(torch.isnan(loss)) > 0:
                    print(inputs)
                    print(target)
                    print(loss)
                    print("epoch_", epoch, "_iter_", j)
                    continue

                mean_loss = float(loss.mean())
                if mean_loss < 0:
                    print("Negative loss", mean_loss, "at epoch", epoch, "iter", j)

                loss_tracker.append(mean_loss)
                run_iter += 1

                loss.mean().backward()
                optimizer.step()

            # Calculate validation loss after each sub-epoch
            batch_losses = []
            with torch.no_grad():
                for j, batch in enumerate(validation_loader):
                    inputs = batch["ref_targets"]
                    target = batch["aracena_targets"]

                    if model_type == "linear":
                        inputs = inputs.to(device)
                        target = target.to(device)
                    elif model_type == "cnn":
                        inputs = inputs.swapaxes(1, 2).to(device)
                        target = target.swapaxes(1, 2).to(device)

                    model_predictions = cur_model(inputs)
                    loss = criterion(model_predictions, target)
                    batch_losses.append(float(loss.mean()))

            # Record the mean validation loss at the current training iteration
            valid_loss[run_iter] = sum(batch_losses) / len(batch_losses)

            # Save a checkpoint (the file for the current epoch is overwritten after each sub-epoch)
            outfile = os.path.join(outdir, f"ref_to_aracena_{model_type}_lr{lr}_se{sub_epochs}_epoch{epoch}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": cur_model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": torch.Tensor(loss_tracker),
                },
                outfile,
            )
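Once training finishes, the per-batch training losses in loss_tracker and the per-sub-epoch validation losses in valid_loss can be plotted to check convergence. A minimal sketch using matplotlib (the output filename is arbitrary; os and outdir come from the training cell above):

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 4))

# Per-batch training loss over the whole run
ax.plot(range(len(loss_tracker)), loss_tracker, alpha=0.5, label="training loss (per batch)")

# Validation loss recorded at specific training iterations
ax.plot(list(valid_loss.keys()), list(valid_loss.values()), marker="o", label="validation loss")

ax.set_xlabel("training iteration")
ax.set_ylabel("Poisson NLL")
ax.legend()
fig.savefig(os.path.join(outdir, "loss_curves.png"), dpi=150)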