"""Data Processing Functions

Description:

    This file contains the functions required for reading and loading the data into the network and preparing the various data files.

Usage:

    To use content from this folder, import the functions and instantiate them as you wish to use them:

        from utils.data_utils import function_name

"""


import os
import numpy as np
import torch
import torch.utils.data as data
import h5py

from fsl.data.image import Image
from fsl.utils.image.resample import resampleToPixdims
from fsl.utils.image.roi import roi

class DataMapper(data.Dataset):
    """Data Mapper Class

    Maps integer keys to (input, target) volume pairs.

    A thin subclass of ``torch.utils.data.Dataset``: it stores two
    index-aligned collections of preprocessed volumes and serves matching
    pairs as torch tensors via ``__getitem__``; ``__len__`` reports the
    dataset size.

    Args:
        X: Indexable collection of input preprocessed volumes
            (e.g. a numpy array read out of an hdf5 database).
        y: Indexable collection of target preprocessed volumes.

    Returns (per item):
        X_volume (torch.Tensor): Tensor representation of the input data
        y_volume (torch.Tensor): Tensor representation of the output data

    """

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __getitem__(self, index):
        # torch.from_numpy shares memory with the underlying numpy array,
        # so no copy is made here.
        input_volume = torch.from_numpy(self.X[index])
        target_volume = torch.from_numpy(self.y[index])
        return input_volume, target_volume

    def __len__(self):
        # Dataset length is defined by the number of target volumes.
        return len(self.y)
def get_datasets(data_parameters):
    """Data Loader Function.

    This function loads the various data files and returns the relevant mapped datasets.

    Args:
        data_parameters (dict): Dictionary containing relevant information for the datafiles.
        data_parameters = {
                            data_folder_name = "datasets"
                            input_data_train = "input_data_train.h5"
                            target_data_train = "target_data_train.h5"
                            input_data_validation = "input_data_validation.h5"
                            target_data_validation = "target_data_validation.h5"
                           }

    Returns:
        tuple: the relevant train and validation datasets (as DataMapper objects)

    """

    folder = data_parameters["data_folder_name"]

    def _read_h5(file_key, dataset_key):
        # Read the whole dataset into memory and close the file handle.
        # (The previous version left every hdf5 file open for the lifetime
        # of the process.)
        with h5py.File(os.path.join(folder, data_parameters[file_key]), 'r') as data_file:
            return data_file[dataset_key][()]

    return (
        DataMapper(_read_h5("input_data_train", 'input'),
                   _read_h5("target_data_train", 'target')),
        DataMapper(_read_h5("input_data_validation", 'input'),
                   _read_h5("target_data_validation", 'target')),
    )
def load_file_paths(data_directory, data_list, mapping_data_file, targets_directory=None, target_file=None):
    """File Loader

    This function returns a list of combined file paths for the input and output data.

    Args:
        data_directory (str): Path to input data directory
        data_list (str): Path to a .txt file containing the input files for consideration
        mapping_data_file (str): Name of the input file inside each subject folder
        targets_directory (str): Path to labelled data (Y-equivalent); None if during evaluation.
        target_file (str): Name of the target file; None if during evaluation.

    Returns:
        file_paths (list): List containing the input data and target labelled output data
        volumes_to_be_used (list): List containing the volumes that will be used

    """

    volumes_to_be_used = load_subjects_from_path(data_directory, data_list)

    # 'is None' rather than '== None': identity check is the correct idiom
    # and avoids surprises with custom __eq__ implementations.
    if targets_directory is None or target_file is None:
        # Evaluation mode: only the input path is needed for each volume.
        file_paths = [[os.path.join(data_directory, volume, mapping_data_file)]
                      for volume in volumes_to_be_used]
    else:
        # NOTE(review): target_file gates this branch but is never joined
        # into the target path below — confirm whether the target path
        # should be os.path.join(targets_directory, volume, target_file).
        file_paths = [[os.path.join(data_directory, volume, mapping_data_file),
                       os.path.join(targets_directory, volume)]
                      for volume in volumes_to_be_used]

    return file_paths, volumes_to_be_used
def load_subjects_from_path(data_directory, data_list):
    """ Text File Reader

    Returns the list of subject volumes to be processed.

    If ``data_list`` is provided, it is read as a text file with one
    subject name per line; otherwise every entry of ``data_directory``
    is used.

    Args:
        data_directory (str): Path to input data directory
        data_list (str): Path to a .txt file containing the input files for consideration

    Returns:
        volumes_to_be_used (list): List containing the volumes that will be used

    """

    # Falsy data_list (None or empty string) falls back to a directory listing.
    if not data_list:
        return list(os.listdir(data_directory))

    with open(data_list) as subjects_file:
        return subjects_file.read().splitlines()
def load_and_preprocess_evaluation(file_path, crop_flag):
    """Load & Preprocessing before evaluation

    This function loads a nifti file, resamples it to 2mm isotropic voxels
    and, optionally, crops it, returning the volume and header information.

    Args:
        file_path (list): Paths to the desired file; only the first entry is read.
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training

    Returns:
        volume (np.array): Array of training image data of data type dtype.
        header (class): 'nibabel.nifti1.Nifti1Header' class object, containing image metadata
        xform (np.array): Array of shape (4, 4), containing the adjusted voxel-to-world transformation for the spatial dimensions of the resampled data

    """

    original_image = Image(file_path[0])

    # Both branches resample identically, so do it once. Previously the
    # code compared 'crop_flag == False' / 'elif crop_flag == True', which
    # left 'volume' unbound (UnboundLocalError) for any non-bool argument.
    resampled_data, xform = resampleToPixdims(original_image, (2, 2, 2))

    if crop_flag:
        # Crop 91x109x91 -> 72x90x77 to cut storage and speed up training.
        resampled = Image(resampled_data, header=original_image.header, xform=xform)
        cropped = roi(resampled, ((9, 81), (10, 100), (0, 77)))
        volume = cropped.data
        header = cropped.header
    else:
        volume = resampled_data
        header = Image(resampled_data, header=original_image.header, xform=xform).header

    return volume, header, xform
def load_and_preprocess_targets(target_path, mean_mask_path):
    """Load & Preprocessing targets before evaluation

    This function loads a nifti target file and returns its first volume
    together with a de-meaned copy (group mean subtracted).

    Args:
        target_path (list): Paths to the desired target file; only the first entry is read.
        mean_mask_path (str): Path to the dualreg subject mean mask

    Returns:
        target (np.array): Array of training image data of data type dtype.
        target_demeaned (np.array): Array of training data from which the group mean has been subtracted

    """

    # Both images are indexed [:, :, :, 0]: only the first volume along the
    # last axis is used — assumes 4D nifti inputs (TODO confirm with callers).
    target = Image(target_path[0]).data[:, :, :, 0]
    target_demeaned = np.subtract(target, Image(mean_mask_path).data[:, :, :, 0])

    return target, target_demeaned