Commit 9a98d8cc authored by Andrei-Claudiu Roibu

created standalone svd/pca analysis script for pretrained autoencoders

parent 2e9a8588
""" SVD Evaluator
Description:
This file contains the required fuctions for performing single value decomposition testing.
This file is designed to be a standalone package, intendead to be run separately from the main network.
Usage:
To use content from this folder, import the functions and instantiate them as you wish to use them:
from utils.svd import function_name
"""
import os
import sys
import pickle
import numpy as np
import logging
import h5py
import argparse
import torch
import torch.nn as nn
import modules
import data_utils
import pandas as pd
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
from settings import Settings
from torch.nn.init import _calculate_fan_in_and_fan_out as calculate_fan
from common_utils import create_folder
from fsl.data.image import Image
from fsl.utils.image.roi import roi
from fsl.utils.image.resample import resampleToPixdims
sns.set()
# First we need to create two networks
class Encoder3D(nn.Module):
"""Architecture class for loading the encoder bit of a 3D Autoencoder.
"""
def __init__(self, parameters):
super(Encoder3D, self).__init__()
original_kernel_size = parameters['kernel_size']
original_stride = parameters['convolution_stride']
# Encoder Path
parameters['kernel_size'] = parameters['first_kernel_size']
parameters['convolution_stride'] = parameters['first_convolution_stride']
self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
parameters['kernel_size'] = original_kernel_size
parameters['convolution_stride'] = original_stride
self.equal_channels_blocks = 0
for _ in range(parameters['number_of_encoder_blocks']):
if parameters['output_channels'] < parameters['max_number_channels']:
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
else:
parameters['input_channels'] = parameters['output_channels']
self.equal_channels_blocks += 1
self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))
parameters['convolution_stride'] = parameters['transformer_blocks_stride']
self.input_channels = parameters['output_channels']
self.output_channels = parameters['output_channels']
self.convolution_stride = parameters['convolution_stride']
self.kernel_size = parameters['kernel_size']
    def return_docoder_info(self):
        """Returns the information required to construct the matching Decoder3D."""
        return self.equal_channels_blocks, self.input_channels, self.output_channels, self.convolution_stride, self.kernel_size
def forward(self, X):
"""Forward pass for 3D Encoder
"""
# Encoder
Y_encoder_sizes = []
for encoderBlock in self.encoderBlocks:
X = encoderBlock.forward(X)
Y_encoder_sizes.append(X.size())
Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
Y_encoder_sizes_lenght = len(Y_encoder_sizes)
return X, Y_encoder_sizes, Y_encoder_sizes_lenght
def save(self, path):
"""Model Saver
Function saving the model with all its parameters to a given path.
The path must end with a *.model argument.
Args:
path (str): Path string
"""
print("Saving Model... {}".format(path))
torch.save(self, path)
@property
def test_if_cuda(self):
"""Cuda Test
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Returns:
            bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
"""
return next(self.parameters()).is_cuda
def reset_parameters(self, custom_weight_reset_flag):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
Args:
custom_weight_reset_flag (bool): Flag indicating if the modified weight initialisation approach should be used.
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
if isinstance(submodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
submodule.reset_parameters()
if custom_weight_reset_flag == True:
if isinstance(submodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(submodule.weight)
std = np.divide(gain, np.sqrt(fan))
submodule.weight.data.normal_(0, std)
for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
subsubmodule.reset_parameters()
if custom_weight_reset_flag == True:
if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(subsubmodule.weight)
std = np.divide(gain, np.sqrt(fan))
subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!")
class Decoder3D(nn.Module):
"""Architecture class for loading the encoder bit of a 3D Autoencoder.
"""
def __init__(self, parameters, equal_channels_blocks, input_channels, output_channels, convolution_stride, kernel_size):
super(Decoder3D, self).__init__()
# Decoder
parameters['input_channels'] = input_channels
parameters['output_channels'] = output_channels
parameters['convolution_stride'] = convolution_stride
parameters['kernel_size'] = kernel_size
if equal_channels_blocks != 0:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for i in range(equal_channels_blocks)])
parameters['output_channels'] = parameters['output_channels'] // 2
if equal_channels_blocks != 0:
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
else:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])
for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
parameters['input_channels'] = parameters['output_channels']
self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))
def forward(self, X, Y_encoder_sizes, Y_encoder_sizes_lenght):
"""Forward pass for 3D Decoder
"""
# Decoder
for index, decoderBlock in enumerate(self.decoderBlocks):
if index < Y_encoder_sizes_lenght:
X = decoderBlock.forward(X, Y_encoder_sizes[index])
else:
X = decoderBlock.forward(X)
return X
def save(self, path):
"""Model Saver
Function saving the model with all its parameters to a given path.
The path must end with a *.model argument.
Args:
path (str): Path string
"""
print("Saving Model... {}".format(path))
torch.save(self, path)
@property
def test_if_cuda(self):
"""Cuda Test
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Returns:
            bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
"""
return next(self.parameters()).is_cuda
def reset_parameters(self, custom_weight_reset_flag):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
Args:
custom_weight_reset_flag (bool): Flag indicating if the modified weight initialisation approach should be used.
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
if isinstance(submodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
submodule.reset_parameters()
if custom_weight_reset_flag == True:
if isinstance(submodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(submodule.weight)
std = np.divide(gain, np.sqrt(fan))
submodule.weight.data.normal_(0, std)
for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
subsubmodule.reset_parameters()
if custom_weight_reset_flag == True:
if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(subsubmodule.weight)
std = np.divide(gain, np.sqrt(fan))
subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!")
def load_pretrained_models(EncoderModel, DecoderModel, save_model_directory, pretrained_network_name):
""" Pretrained loader
"""
sys.path.insert(0, './') # https://github.com/pytorch/pytorch/issues/3678
pretrained_model_state_dict = torch.load(os.path.join(save_model_directory, pretrained_network_name)).state_dict()
EncoderModel_state_dict = EncoderModel.state_dict()
DecoderModel_state_dict = DecoderModel.state_dict()
half_point = len(pretrained_model_state_dict) // 2 + 1
counter = 1
for key, _ in pretrained_model_state_dict.items():
if counter <= half_point:
EncoderModel_state_dict.update({key : pretrained_model_state_dict[key]})
counter+=1
else:
if key in DecoderModel_state_dict:
DecoderModel_state_dict.update({key : pretrained_model_state_dict[key]})
EncoderModel.load_state_dict(EncoderModel_state_dict)
DecoderModel.load_state_dict(DecoderModel_state_dict)
return EncoderModel, DecoderModel
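# Minimal loading sketch for the two standalone networks, using the `parameters` sketch above
# (the directory and file name are placeholders; the saved file is assumed to be a full
# torch-saved autoencoder, as implied by the .state_dict() call above):
#
#   encoder = Encoder3D(parameters)
#   decoder = Decoder3D(parameters, *encoder.return_docoder_info())
#   encoder, decoder = load_pretrained_models(encoder, decoder,
#                                             save_model_directory='saved_models',
#                                             pretrained_network_name='pretrained_autoencoder.pth.tar')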
def evaluate_data(pretrained_model,
data_directory,
mapping_data_file,
mapping_targets_file,
data_list,
prediction_output_path,
prediction_output_database_name,
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
regression_factors,
mean_regression_flag,
mean_regression_all_flag,
mean_subtraction_flag,
scale_volumes_flag,
normalize_flag,
minus_one_scaling_flag,
negative_flag,
outlier_flag,
shrinkage_flag,
hard_shrinkage_flag,
crop_flag,
device=0,
exit_on_error=False,
cross_domain_x2x_flag=False,
cross_domain_y2y_flag=False,
mode='evaluate'):
"""Model Evaluator
This function generates the rsfMRI arrays for the given inputs
"""
with open(data_list) as data_list_file:
volumes_to_be_used = data_list_file.read().splitlines()
cuda_available = torch.cuda.is_available()
if type(device) == int:
if cuda_available:
model = pretrained_model
torch.cuda.empty_cache()
model.cuda(device)
else:
device = 'cpu'
if (type(device) == str) or not cuda_available:
model = pretrained_model
model.eval()
create_folder(prediction_output_path)
if cross_domain_y2y_flag == True:
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file)
else:
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file)
output_database_path = os.path.join(prediction_output_path, prediction_output_database_name)
if os.path.exists(output_database_path):
os.remove(output_database_path)
output_database_handle = h5py.File(output_database_path, 'w')
with torch.no_grad():
for volume_index, file_path in enumerate(file_paths):
print("Mapping Volume {}/{}".format(volume_index+1, len(file_paths)))
# Generate volume & header
subject = volumes_to_be_used[volume_index]
output = _generate_volume_map(file_path,
subject,
model,
device,
cuda_available,
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
regression_factors,
mean_regression_flag,
mean_regression_all_flag,
mean_subtraction_flag,
scale_volumes_flag,
normalize_flag,
minus_one_scaling_flag,
negative_flag,
outlier_flag,
shrinkage_flag,
hard_shrinkage_flag,
crop_flag,
cross_domain_x2x_flag,
cross_domain_y2y_flag)
group = output_database_handle.create_group(subject)
group.create_dataset('output', data=output)
output_database_handle.close()
return volumes_to_be_used
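# Reading sketch for the HDF5 database written above, with one group per subject and a single
# 'output' dataset per group (the path is a placeholder):
#
#   with h5py.File('predictions/svd_outputs.h5', 'r') as database:
#       for subject in database:
#           output = database[subject]['output'][()]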
def _generate_volume_map(file_path,
subject,
model,
device,
cuda_available,
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
regression_factors,
mean_regression_flag,
mean_regression_all_flag,
mean_subtraction_flag,
scale_volumes_flag,
normalize_flag,
minus_one_scaling_flag,
negative_flag,
outlier_flag,
shrinkage_flag,
hard_shrinkage_flag,
crop_flag,
cross_domain_x2x_flag,
cross_domain_y2y_flag
):
"""rsfMRI Volume Generator
"""
volume, _, _ = data_utils.load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag)
if mean_regression_flag == True:
if mean_regression_all_flag == True:
volume = _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag)
scaling_parameters = [-0.0626, 0.1146, -14.18, 16.9475]
else:
scaling_parameters = [0.0, 0.2, -14.18, 16.9475]
elif mean_subtraction_flag == True:
scaling_parameters = [0.0, 0.2, 0.0, 10.0]
if scale_volumes_flag == True:
volume = _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag)
if len(volume.shape) == 5:
volume = volume
else:
volume = volume[np.newaxis, np.newaxis, :, :, :]
volume = torch.tensor(volume).type(torch.FloatTensor)
if cuda_available and (type(device) == int):
volume = volume.cuda(device)
output = model(volume)[0]
output = (output.cpu().numpy()).astype('float32')
output = np.squeeze(output)
return output
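# Shape note for the forward call above (sizes are illustrative, matching the cropped volumes
# described below): a single 3D volume of shape (72, 90, 77) is promoted to (1, 1, 72, 90, 77)
# by the np.newaxis indexing, and np.squeeze removes the singleton dimensions from the output.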
def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag):
"""Input Scaling
    This function applies the pre-computed scaling factors to scale the input data.
Args:
        volume (np.array): Numpy array representing the unscaled volume.
scaling_parameters (list): List of scaling parameters.
normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
Returns:
scaled_volume (np.array): Scaled volume
"""
if cross_domain_y2y_flag == True:
_, _, min_value, max_value = scaling_parameters
else:
min_value, max_value, _, _ = scaling_parameters
if shrinkage_flag == True:
if cross_domain_y2y_flag == True:
lambd = 3.0
else:
            lambd = 0.003  # Hard-coded; equivalent to the 1st and 99th percentile values across the whole population in UKBB
if hard_shrinkage_flag == True:
volume = _hard_shrinkage(volume, lambd)
elif hard_shrinkage_flag == False:
volume = _soft_shrinkage(volume, lambd)
min_value += lambd
max_value -= lambd
if negative_flag == True:
volume[volume < 0.0] = 0.0
min_value = 0.0
if outlier_flag == True:
volume[volume > max_value] = max_value
volume[volume < min_value] = min_value
if normalize_flag == True:
# Normalization to [0, 1]
scaled_volume = np.divide(np.subtract(volume, min_value), np.subtract(max_value, min_value))
elif minus_one_scaling_flag == True:
# Scaling between [-1, 1]
scaled_volume = np.add(-1.0, np.multiply(2.0, np.divide(np.subtract(volume, min_value), np.subtract(max_value, min_value))))
    # Else, no scaling occurs, but the other flags can still hold true if the scaling flag is true!
return scaled_volume
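# Worked example of the scaling arithmetic above, ignoring the shrinkage adjustment and using
# the mean-regression scaling parameters from _generate_volume_map (illustrative only):
# with (min_value, max_value) = (-0.0626, 0.1146) and a voxel value of 0.026,
#   normalize_flag:         (0.026 - (-0.0626)) / (0.1146 - (-0.0626)) = 0.5
#   minus_one_scaling_flag: -1.0 + 2.0 * 0.5 = 0.0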
def _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag):
""" Inputn Regression
This function regresse the group mean from the input volume using the saved regression weights.
TODO: This function repressents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.
Args:
volume (np.array): Unregressed volume
subject (str): Subject ID of the subject volume to be regressed
dmri_mean_mask_path (str): Path to the group mean volume
rsfmri_mean_mask_path (str): Path to the target group mean volume
regression_factors (str): Path to the linear regression weights file
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed up training
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
Returns:
regressed_volume (np.array): Linear regressed volume
"""
if cross_domain_y2y_flag == True:
weight = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI']
if crop_flag == False:
group_mean = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
elif crop_flag == True:
group_mean = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
else:
weight = pd.read_pickle(regression_factors).loc[subject]['w_dMRI']
if crop_flag == False:
group_mean = Image(dmri_mean_mask_path).data
elif crop_flag == True:
group_mean = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))
return regressed_volume
def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag):
"""Output Rescaling
    This function applies the pre-computed scaling factors to rescale the network output back to its original range.
Args:
        volume (np.array): Unscaled volume
scaling_parameters (list): List of scaling parameters.
normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
Returns:
rescaled_volume (np.array): Rescaled volume
"""
if cross_domain_x2x_flag == True:
min_value, max_value, _, _ = scaling_parameters
else:
_, _, min_value, max_value = scaling_parameters
if shrinkage_flag == True:
if cross_domain_x2x_flag == True:
lambd = 0.003
else:
lambd = 3.0
if hard_shrinkage_flag == True:
pass
elif hard_shrinkage_flag == False:
min_value += lambd
max_value -= lambd
if negative_flag == True:
min_value = 0.0
if normalize_flag == True:
# Normalization to [0, 1]
rescaled_volume = np.add(np.multiply(volume, np.subtract(max_value, min_value)), min_value)
elif minus_one_scaling_flag == True:
# Scaling between [-1, 1]
rescaled_volume = np.add(np.multiply(np.divide(np.add(volume, 1), 2), np.subtract(max_value, min_value)), min_value)
    # Else, no rescaling occurs, but the other flags can still hold true if the scaling flag is true!
return rescaled_volume
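# Worked example of the inverse mapping above, using the target scaling parameters from
# _generate_volume_map and ignoring the shrinkage adjustment (illustrative only):
# with (min_value, max_value) = (-14.18, 16.9475) and a network output of 0.5,
#   normalize_flag:         0.5 * (16.9475 - (-14.18)) + (-14.18) ≈ 1.384
#   minus_one_scaling_flag: ((0.5 + 1.0) / 2.0) * (16.9475 - (-14.18)) + (-14.18) ≈ 9.166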
def _hard_shrinkage(volume, lambd):
""" Hard Shrinkage
This function performs a hard shrinkage on the volumes.
    volume = { x,  if x > lambd or x < -lambd
               0,  if x in [-lambd, lambd] }
Args:
volume (np.array): Unshrunken volume
lambd (float): Threshold parameter
Returns:
volume (np.array) : Hard shrunk volume
"""
volume[np.where(np.logical_and(volume>-lambd, volume<lambd))] = 0
return volume
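# For reference, this matches PyTorch's hard shrinkage operator (up to the behaviour at exactly
# +/- lambd); a tensor-based equivalent would be:
#
#   import torch.nn.functional as F
#   shrunk = F.hardshrink(torch.from_numpy(volume), lambd=lambd).numpy()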
def _soft_shrinkage(volume, lambd):
""" Soft Shrinkage
This function performs a soft shrinkage on the volumes.
    volume = { x + lambd,  if x < -lambd
               0,          if x in [-lambd, lambd]
               x - lambd,  if x > lambd }
Args:
volume (np.array): Unshrunken volume
lambd (float): Threshold parameter