train_aracena_tss_predictors_evaluation

Author

Saideep Gona

Published

October 4, 2023

Context

Here I trained an elastic net and a simple MLP model to predict TSS RNA-seq expression in the flu condition (Aracena dataset) from Enformer predictions. Training was done on a corpus of TSS sites drawn from a random mixture of individuals. Here I evaluate the performance of the trained models on held-out data.

TODO:
- Average across individuals.
- Collect RNA-seq from Katie, and use it both for DL-EN training AND for training PrediXcan.
- Try different MLP parameters.

Code
suppressMessages(library(tidyverse))
suppressMessages(library(glue))
# Root of the local blog-data checkout; per-post data lives in a subfolder.
PRE <- "/Users/saideepgona/Library/CloudStorage/Box-Box/imlab-data/data-Github/Daily-Blog-Sai"

## COPY THE DATE AND SLUG fields FROM THE HEADER
SLUG <- "train_aracena_tss_predictor" ## copy the slug from the header
bDATE <- "2023-10-04" ## copy the date from the blog's header here
DATA <- glue("{PRE}/{bDATE}-{SLUG}")
# dir.create() avoids shelling out: it works with spaces or shell
# metacharacters in the path and also creates missing parent directories.
if (!dir.exists(DATA)) dir.create(DATA, recursive = TRUE)
WORK <- DATA
Code
library(httpgd)
library(glmnet)
library(glue)
library(data.table)
library(neuralnet)
library(dplyr)
Code
#' Evaluate the trained elastic-net and MLP models on one data subset.
#'
#' Loads both models and predicts on `new_{data_subset}_inputs.csv`. It then
#' saves prediction-vs-truth scatter plots and prints the Pearson correlation
#' of each model's predictions with column 1 of
#' `new_{data_subset}_targets.csv`, which the plots label "flu rnaseq".
#'
#' @param data_subset Subset name used in the file names ("train", "valid",
#'   "test").
#' @param glmnet_model Path to an RDS file holding the fitted glmnet model.
#'   Predictions use `s = "lambda.min"`, so this is presumably a `cv.glmnet`
#'   fit (confirm).
#' @param nn_model Path to an RDS file holding the fitted MLP model.
#' @param data_dir Directory containing the `new_*_inputs.csv` /
#'   `new_*_targets.csv` files.
#' @param post_dir Blog post directory. The lambda plot is written here;
#'   scatter plots go to `post_dir/figures`.
#' @return Invisibly, the correlations for the glmnet and NN predictions,
#'   named `glmnet` and `nn`.
predict_on_data_subset <- function(
    data_subset,
    glmnet_model,
    nn_model = "/beagle3/haky/users/saideep/projects/aracena_modeling/elastic_net/trained_mlp.rds",
    data_dir = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training_tss_centered",
    post_dir = "/beagle3/haky/users/saideep/github_repos/Daily-Blog-Sai/posts/2023-10-04-train_aracena_tss_predictor") {

  # Draw `plot_expr` into a PNG at `path`. The device is closed even if
  # plotting errors, so a failure cannot leave a stray device open.
  # `plot_expr` is a lazy promise, so it runs only after png() opens the file.
  save_png <- function(path, plot_expr) {
    grDevices::png(path)
    on.exit(grDevices::dev.off(), add = TRUE)
    force(plot_expr)
    invisible(path)
  }

  fig_dir <- file.path(post_dir, "figures")

  # Load trained models
  glmnet_fit <- readRDS(glmnet_model)
  nn_fit <- readRDS(nn_model)

  # The lambda path depends only on the model, so this file is overwritten
  # with the same plot on every call.
  save_png(file.path(post_dir, "glmnet_lambdas.png"), plot(glmnet_fit))

  # Load the subset to be predicted on
  X <- as.matrix(data.table::fread(
    glue("{data_dir}/new_{data_subset}_inputs.csv")
  ))
  y <- as.data.frame(data.table::fread(
    glue("{data_dir}/new_{data_subset}_targets.csv")
  ))
  if (nrow(X) != nrow(y)) {
    stop(
      glue("Inputs ({nrow(X)} rows) and targets ({nrow(y)} rows) differ ",
           "for the {data_subset} data subset"),
      call. = FALSE
    )
  }
  truth <- y[, 1]

  # Elastic net predictions
  glmnet_predictions <- predict(glmnet_fit, X, s = "lambda.min", type = "response")

  save_png(
    glue("{fig_dir}/{data_subset}_corr_en_pred.png"),
    plot(glmnet_predictions, truth, ylim = c(-1, 20),
         ylab = "ground truth flu rnaseq", xlab = "glmnet predictions",
         main = glue("Glmnet predictions vs. ground truth for {data_subset} data subset"))
  )

  # drop() turns the 1x1 matrix returned by cor(matrix, vector) into a scalar.
  pred_corr_en <- drop(cor(glmnet_predictions, truth))
  print(glue("Correlation of {pred_corr_en} for {data_subset} data subset with glmnet predictions"))

  # MLP predictions. The target is appended under the column name the
  # original code produced ("y_train_for_cv"); presumably this mirrors the
  # training-time layout and predict() picks the model's covariates by name.
  # NOTE(review): confirm the target column is not used as an input.
  nn_newdata <- data.frame(cbind(X, y_train_for_cv = truth))
  predictions_nn <- predict(nn_fit, nn_newdata)

  save_png(
    glue("{fig_dir}/{data_subset}_corr_nn_pred.png"),
    plot(predictions_nn, truth, ylim = c(-1, 20),
         ylab = "ground truth flu rnaseq", xlab = "NN predictions",
         main = glue("NN predictions vs. ground truth for {data_subset} data subset"))
  )

  pred_corr_nn <- drop(cor(predictions_nn, truth))
  print(glue("Correlation of {pred_corr_nn} for {data_subset} data subset with NN predictions"))

  invisible(c(glmnet = pred_corr_en, nn = pred_corr_nn))
}
Code
#' Compare ground-truth RNA-seq targets with each other and with one input
#' track for a data subset.
#'
#' Writes three scatter plots to `fig_dir` and prints three Pearson
#' correlations:
#' - flu vs. NI targets;
#' - flu target vs. the selected input column;
#' - NI target vs. the selected input column.
#'
#' @param data_subset Subset name used in the file names ("train", "valid",
#'   "test").
#' @param flu_col Column of the targets file holding flu RNA-seq (default 1).
#' @param ni_col Column of the targets file holding NI RNA-seq (default 7).
#' @param epi_col Column of the inputs file plotted as the reference CAGE
#'   monocyte-derived-macrophage Enformer track (default 4766).
#' @param data_dir Directory containing the `new_*_inputs.csv` /
#'   `new_*_targets.csv` files.
#' @param fig_dir Directory the PNG figures are written to.
#' @return Invisibly, the three correlations, named `flu_vs_ni`,
#'   `flu_vs_epi` and `ni_vs_epi`.
analyze_input_data <- function(
    data_subset,
    flu_col = 1,
    ni_col = 7,
    epi_col = 4766,
    data_dir = "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training_tss_centered",
    fig_dir = "/beagle3/haky/users/saideep/github_repos/Daily-Blog-Sai/posts/2023-10-04-train_aracena_tss_predictor/figures") {

  # Draw `plot_expr` into a PNG at `path`. The device is closed even if
  # plotting errors. `plot_expr` is a lazy promise evaluated after png().
  save_png <- function(path, plot_expr) {
    grDevices::png(path)
    on.exit(grDevices::dev.off(), add = TRUE)
    force(plot_expr)
    invisible(path)
  }

  X <- as.matrix(data.table::fread(glue("{data_dir}/new_{data_subset}_inputs.csv")))
  y <- as.data.frame(data.table::fread(glue("{data_dir}/new_{data_subset}_targets.csv")))
  stopifnot(epi_col <= ncol(X), max(flu_col, ni_col) <= ncol(y))

  flu <- y[, flu_col]
  ni <- y[, ni_col]
  epi <- X[, epi_col]
  epi_label <- "Reference CAGE epigenome monocyte derived macrophages"

  # Flu vs. NI targets
  save_png(
    glue("{fig_dir}/{data_subset}_flu_vs_NI_rnaseq.png"),
    plot(flu, ni, ylim = c(-1, 20), xlim = c(-1, 20),
         xlab = "ground truth flu rnaseq", ylab = "ground truth NI rnaseq",
         main = glue("Flu vs. NI in ground truth for {data_subset} data subset"))
  )
  corr_rnaseq <- cor(flu, ni)
  print(glue("Correlation of {corr_rnaseq} for {data_subset} data subset between flu/NI rnaseq"))

  # Flu target vs. reference input track
  save_png(
    glue("{fig_dir}/{data_subset}_flu_vs_epi_rnaseq.png"),
    plot(flu, epi, xlim = c(-1, 20),
         xlab = "ground truth flu rnaseq", ylab = epi_label,
         main = glue("Reference epi vs. ground truth for {data_subset} data subset"))
  )
  corr_flu_epi <- cor(flu, epi)
  print(glue("Correlation of {corr_flu_epi} for {data_subset} data subset between aracena flu rnaseq and corresponding enformer prediction"))

  # NI target vs. reference input track. The x axis is the NI target; it was
  # previously mislabelled as flu.
  save_png(
    glue("{fig_dir}/{data_subset}_NI_vs_epi_rnaseq.png"),
    plot(ni, epi, xlim = c(-1, 20),
         xlab = "ground truth NI rnaseq", ylab = epi_label,
         main = glue("Reference epi vs. NI ground truth for {data_subset} data subset"))
  )
  corr_ni_epi <- cor(ni, epi)
  print(glue("Correlation of {corr_ni_epi} for {data_subset} data subset between aracena NI rnaseq and corresponding enformer prediction"))

  invisible(c(flu_vs_ni = corr_rnaseq, flu_vs_epi = corr_flu_epi, ni_vs_epi = corr_ni_epi))
}

Let’s first take a look at some properties of the data itself

Code
# Target/input correlations for the training split
analyze_input_data(data_subset = "train")
Correlation of 0.73904726191164 for train data subset between flu/NI rnaseq
Correlation of 0.228283741744927 for train data subset between aracena flu rnaseq and corresponding enformer prediction
Correlation of 0.214077042278839 for train data subset between aracena NI rnaseq and corresponding enformer prediction
Code
# Target/input correlations for the validation split
analyze_input_data(data_subset = "valid")
Correlation of 0.989184938027612 for valid data subset between flu/NI rnaseq
Correlation of 0.0643572220801176 for valid data subset between aracena flu rnaseq and corresponding enformer prediction
Correlation of 0.0819605957767975 for valid data subset between aracena NI rnaseq and corresponding enformer prediction
Code
# Target/input correlations for the test split
analyze_input_data(data_subset = "test")
Correlation of 0.362012765998215 for test data subset between flu/NI rnaseq
Correlation of 0.126590285627353 for test data subset between aracena flu rnaseq and corresponding enformer prediction
Correlation of 0.334457542541329 for test data subset between aracena NI rnaseq and corresponding enformer prediction

Now we can use the trained models to make predictions and assess model performance.

Code
# Evaluate both trained models on the training split
predict_on_data_subset(
  data_subset = "train",
  glmnet_model = "/beagle3/haky/users/saideep/projects/aracena_modeling/elastic_net/trained_eln_RNAseq_from_HDF5_mean.rds.linear.rds"
)
Correlation of 0.437281473980008 for train data subset with glmnet predictions
Correlation of 0.483994813636335 for train data subset with NN predictions
Code
# Evaluate both trained models on the validation split
predict_on_data_subset(
  data_subset = "valid",
  glmnet_model = "/beagle3/haky/users/saideep/projects/aracena_modeling/elastic_net/trained_eln_RNAseq_from_HDF5_mean.rds.linear.rds"
)
Correlation of 0.11744649852639 for valid data subset with glmnet predictions
Correlation of 0.0682833668657665 for valid data subset with NN predictions
Code
# Evaluate both trained models on the test split
predict_on_data_subset(
  data_subset = "test",
  glmnet_model = "/beagle3/haky/users/saideep/projects/aracena_modeling/elastic_net/trained_eln_RNAseq_from_HDF5_mean.rds.linear.rds"
)
Correlation of 0.396900747294329 for test data subset with glmnet predictions
Correlation of 0.539515763787496 for test data subset with NN predictions

For some reason, this randomly selected validation subset appears harder to predict. It shows a weaker correlation between the ground-truth RNA-seq and the CAGE epigenomic track in the input, as well as poorer prediction performance.

Why are the predictions so poor on the validation dataset?

Code
# Region metadata table: the first (unnamed, read as "V1") column holds row
# labels, and there is one further column per region. The "group" row
# records each region's split assignment.
remapped_table <- data.frame(data.table::fread(
  "/beagle3/haky/users/saideep/projects/aracena_modeling/hdf5_training_tss_centered/remapped_table_filt.csv",
  header = TRUE
))

rownames(remapped_table) <- remapped_table$V1
remapped_table$V1 <- NULL

# `%in%` treats a missing group label as "not train" (`==` would select an NA
# column). drop = FALSE keeps a data.frame even when one region matches.
is_train <- unlist(remapped_table["group", ], use.names = FALSE) %in% "train"
remapped_table_train <- remapped_table[, is_train, drop = FALSE]

regions_train <- colnames(remapped_table_train)
# NOTE(review): region names are assumed to look like
# "<1-char prefix><chrom><separator>..." (e.g. an "X" added by data.frame()
# name mangling); confirm. substr(x, 2, 2) kept a single character, folding
# chr10-19 into "1" and chr20-22 into "2". Take the full alphanumeric run
# after the prefix instead.
chroms_train <- sub("[^[:alnum:]].*$", "", substring(regions_train, 2))
table(chroms_train)
chroms_train
   1    2    3    4    9 
7687 1976 1059  748  762 
Code
# Keep only the regions assigned to the "test" group. `%in%` treats a missing
# group label as "not test" instead of selecting an NA column, and
# drop = FALSE keeps a data.frame even if only one region matches.
is_test <- unlist(remapped_table["group", ], use.names = FALSE) %in% "test"
remapped_table_test <- remapped_table[, is_test, drop = FALSE]

dim(remapped_table_test)
[1]   29 2041
Code
regions <- colnames(remapped_table_test)
# NOTE(review): region names are assumed to look like
# "<1-char prefix><chrom><separator>..." (e.g. an "X" added by data.frame()
# name mangling); confirm. substr(x, 2, 2) kept a single character, folding
# chr10-19 into "1" and chr20-22 into "2". Take the full alphanumeric run
# after the prefix instead.
chroms <- sub("[^[:alnum:]].*$", "", substring(regions, 2))

# The "test"-group regions are split in half by column order: the first half
# is treated as the validation set and the second half as the test set.
# (Presumably this matches how the valid/test CSVs were produced; confirm.)
half_test <- floor(length(regions) / 2)

# seq_len() yields an empty selection when half_test is 0; 1:0 would index
# c(1, 0).
chroms_valid <- chroms[seq_len(half_test)]
table(chroms_valid)
chroms_valid
  1   2 
919 101 
Code
# Second half of the "test"-group regions, by column order. Comparing
# positions with half_test stays correct for empty input, where
# (half_test+1):length(regions) would count down to c(1, 0).
chroms_test <- chroms[seq_along(chroms) > half_test]
table(chroms_test)
chroms_test
  2   8 
330 691