import h5py
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
class HDF5DatasetRefAracena(Dataset):
    """Pairs Enformer reference-epigenome targets with Aracena targets.

    Both sources are HDF5 files; `mapping_table` is a pandas DataFrame with
    one column per region and one row per source ("ref-epigenome" plus one
    row per individual), giving each region's index into the HDF5 arrays."""

    def __init__(self, enformer_ref_targets, aracena_targets, mapping_table,
                 individual, summarization="mean", bin_size=128):
        super().__init__()
        self.bin_size = bin_size
        self.summarization = summarization
        self.individual = individual
        self.mapping_table = mapping_table
        self.files = {}
        self.regions = {}

        # Load both files fully into memory, swapping the first two axes
        # into the layout expected by __getitem__.
        with h5py.File(enformer_ref_targets, 'r') as ert:
            self.files['ref_targets'] = torch.Tensor(
                ert['ref_epigenome'][:, :, :].swapaxes(0, 1)).float()
            self.regions['ref_targets'] = ert['regions'][:, :, :]
        with h5py.File(aracena_targets, 'r') as at:
            self.files['aracena_targets'] = torch.Tensor(
                at['targets'][:, :, :].swapaxes(0, 1)).float()
            self.regions['aracena_targets'] = at['regions'][:, :, :]

        # "mean" keeps the stored per-bin values; "sum" rescales each bin by
        # the bin size, turning a per-bin mean into a per-bin sum.
        if summarization == "sum":
            self.files['ref_targets'] = torch.mul(self.files['ref_targets'], bin_size)
            self.files['aracena_targets'] = torch.mul(self.files['aracena_targets'], bin_size)
        elif summarization != "mean":
            raise ValueError(f"unsupported summarization: {summarization!r}")
    def __getitem__(self, index):
        # Columns of the mapping table are region names; each row maps a
        # region to its column index in the corresponding HDF5 array.
        region = self.mapping_table.columns[index]
        ref_idx = int(self.mapping_table.loc["ref-epigenome", region])
        ara_idx = int(self.mapping_table.loc[self.individual, region])
        dt = {}
        dt['ref_targets'] = self.files['ref_targets'][:, :, ref_idx].swapaxes(0, 1)
        dt['aracena_targets'] = self.files['aracena_targets'][:, :, ara_idx].swapaxes(0, 1)
        dt['ref_region'] = self.regions['ref_targets'][0, :, ref_idx]
        dt['ara_region'] = self.regions['aracena_targets'][0, :, ara_idx]
        return dt
    def __len__(self):
        return len(self.mapping_table.columns)
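
# --- Example (a sketch, not part of the original module) ---
# Single-process sanity check of the dataset. The file paths, region name,
# and individual ID are hypothetical placeholders; the mapping table is
# assumed to be a pandas DataFrame with one column per region and one row
# per HDF5 source ("ref-epigenome" plus one row per individual).
def _example_dataset_check():
    import pandas as pd
    mapping = pd.DataFrame({"chr1:0-196608": [0, 0]},
                           index=["ref-epigenome", "AF20"])
    dataset = HDF5DatasetRefAracena(
        "enformer_ref_targets.h5",  # hypothetical path
        "aracena_targets.h5",       # hypothetical path
        mapping,
        individual="AF20",          # hypothetical individual ID
        summarization="mean",
    )
    item = dataset[0]
    # Each item pairs reference and Aracena tracks for one genomic region.
    print(item['ref_targets'].shape, item['aracena_targets'].shape)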
def prepare_ddp_dataloader(rank, world_size, path_to_input, path_to_targets,
                           group, individual, batch_size=32, pin_memory=False, num_workers=0):
    # Shard the dataset across world_size DDP processes; this process reads shard `rank`.
    dataset = HDF5DatasetRefAracena(path_to_input, path_to_targets, group, individual,
                                    summarization="mean")
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=False, drop_last=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory,
                            num_workers=num_workers, drop_last=False, shuffle=False,
                            sampler=sampler)
    return dataloader
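
# --- Usage sketch (an assumption, not part of the original module) ---
# One way prepare_ddp_dataloader might be driven from a DDP worker. The
# backend, master address/port, file paths, mapping table, and individual
# ID are hypothetical placeholders; torch.distributed is initialized here
# as it would be for DDP training.
def _example_ddp_worker(rank, world_size):
    import os
    import pandas as pd
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    mapping = pd.DataFrame({"chr1:0-196608": [0, 0]},
                           index=["ref-epigenome", "AF20"])
    loader = prepare_ddp_dataloader(
        rank, world_size,
        "enformer_ref_targets.h5",  # hypothetical path
        "aracena_targets.h5",       # hypothetical path
        mapping, "AF20",            # hypothetical group table / individual
    )
    # set_epoch keeps shards deterministic across epochs; with shuffle=False
    # it is effectively a no-op but is cheap to keep.
    loader.sampler.set_epoch(0)
    for batch in loader:
        pass  # a training step on `batch` would go here
    dist.destroy_process_group()

# Launch with, e.g.:
#   torch.multiprocessing.spawn(_example_ddp_worker, args=(2,), nprocs=2)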