Commit 15ffbc18 authored by Andrei-Claudiu Roibu

Added input/output scaling, fixed the final network save when training stops early due to overfitting, and added deletion of old checkpoints.

parent 7c0603da
@@ -119,4 +119,5 @@ dmypy.json
datasets/
files.txt
jobscript.sge.sh
*.nii.gz
\ No newline at end of file
*.nii.gz
stuff/
\ No newline at end of file
@@ -20,8 +20,8 @@ import torch.nn as nn
import utils.modules as modules
class BrainMapperCompResUNet3D(nn.Module):
"""Architecture class for Competitive Residual DenseBlock BrainMapper 3D U-net.
class BrainMapperUNet3D(nn.Module):
"""Architecture class for Traditional BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
@@ -47,25 +47,41 @@ class BrainMapperCompResUNet3D(nn.Module):
"""
def __init__(self, parameters):
super(BrainMapperCompResUNet3D, self).__init__()
super(BrainMapperUNet3D, self).__init__()
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
self.encoderBlock1 = modules.InCompDensEncoderBlock3D(parameters)
self.encoderBlock1 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.encoderBlock2 = modules.CompDensEncoderBlock3D(parameters)
self.encoderBlock3 = modules.CompDensEncoderBlock3D(parameters)
self.encoderBlock4 = modules.CompDensEncoderBlock3D(parameters)
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock2 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock4 = modules.EncoderBlock3D(parameters)
self.bottleneck = modules.CompDensBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.bottleneck = modules.ConvolutionalBlock3D(parameters)
self.decoderBlock1 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock2 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock3 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock4 = modules.CompDensDecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock1 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock2 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock3 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock4 = modules.DecoderBlock3D(parameters)
self.classifier = modules.DensClassifierBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.ClassifierBlock3D(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
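Note: the constructor above threads one mutable parameters dict through every block, doubling output_channels down the encoder path and halving it back up the decoder path, then restoring the cached original values at the end so the dict stays reusable. A minimal sketch of the resulting channel plan, assuming (purely for illustration) that the settings start the network at 1 input channel and 64 feature maps:

params = {'input_channels': 1, 'output_channels': 64}  # assumed starting values, not from the diff

plan = [('encoderBlock1', params['input_channels'], params['output_channels'])]
for name in ('encoderBlock2', 'encoderBlock3', 'encoderBlock4', 'bottleneck'):
    params['input_channels'] = params['output_channels']
    params['output_channels'] *= 2   # channels double on the way down
    plan.append((name, params['input_channels'], params['output_channels']))
for name in ('decoderBlock1', 'decoderBlock2', 'decoderBlock3', 'decoderBlock4'):
    params['input_channels'] = params['output_channels']
    params['output_channels'] //= 2  # channels halve on the way up
    plan.append((name, params['input_channels'], params['output_channels']))

for block, in_channels, out_channels in plan:
    print('{}: {} -> {}'.format(block, in_channels, out_channels))
# encoderBlock1: 1 -> 64, encoderBlock2: 64 -> 128, ..., bottleneck: 512 -> 1024,
# decoderBlock1: 1024 -> 512, ..., decoderBlock4: 128 -> 64; the classifier then
# consumes the final 64 channels.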
@@ -211,9 +227,8 @@ class BrainMapperCompResUNet3D(nn.Module):
print("Initialized network parameters!")
class BrainMapperResUNet3Dshallow(nn.Module):
"""Architecture class for Residual DenseBlock BrainMapper 3D U-net.
class BrainMapperCompResUNet3D(nn.Module):
"""Architecture class for Competitive Residual DenseBlock BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
@@ -239,25 +254,25 @@ class BrainMapperResUNet3Dshallow(nn.Module):
"""
def __init__(self, parameters):
super(BrainMapperResUNet3Dshallow, self).__init__()
super(BrainMapperCompResUNet3D, self).__init__()
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
self.encoderBlock1 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock1 = modules.InCompDensEncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.encoderBlock2 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock3 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock2 = modules.CompDensEncoderBlock3D(parameters)
self.encoderBlock3 = modules.CompDensEncoderBlock3D(parameters)
self.encoderBlock4 = modules.CompDensEncoderBlock3D(parameters)
self.bottleneck = modules.DensBlock3D(parameters)
self.bottleneck = modules.CompDensBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] * 2
self.decoderBlock1 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock2 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock3 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock1 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock2 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock3 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock4 = modules.CompDensDecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.DensClassifierBlock3D(parameters)
self.classifier = modules.CompDensClassifierBlock3D(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
@@ -286,28 +301,38 @@ class BrainMapperResUNet3Dshallow(nn.Module):
del Y_encoder_2
Y_bottleNeck = self.bottleneck.forward(Y_encoder_3)
Y_encoder_4, Y_np4, _ = self.encoderBlock4.forward(
Y_encoder_3)
del Y_encoder_3
Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)
del Y_encoder_4
Y_decoder_1 = self.decoderBlock1.forward(
Y_bottleNeck, Y_np3)
Y_bottleNeck, Y_np4)
del Y_bottleNeck, Y_np3
del Y_bottleNeck, Y_np4
Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np2)
Y_decoder_1, Y_np3)
del Y_decoder_1, Y_np2
del Y_decoder_1, Y_np3
Y_decoder_3 = self.decoderBlock3.forward(
Y_decoder_2, Y_np1)
Y_decoder_2, Y_np2)
del Y_decoder_2, Y_np1
del Y_decoder_2, Y_np2
probability_map = self.classifier.forward(Y_decoder_3)
Y_decoder_4 = self.decoderBlock4.forward(
Y_decoder_3, Y_np1)
del Y_decoder_3
del Y_decoder_3, Y_np1
probability_map = self.classifier.forward(Y_decoder_4)
del Y_decoder_4
return probability_map
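Note: the hunk above interleaves the removed three-level forward pass with the added four-level one, so the new control flow is easier to see consolidated. A sketch of the updated forward (the first two encoder calls are assumed to follow the same pattern as the ones visible in the hunk; each encoder block returns three values and the third is unused here, and intermediate tensors are deleted eagerly to limit GPU memory):

def forward(self, X):
    # encoder path
    Y_encoder_1, Y_np1, _ = self.encoderBlock1.forward(X)
    Y_encoder_2, Y_np2, _ = self.encoderBlock2.forward(Y_encoder_1)
    del Y_encoder_1
    Y_encoder_3, Y_np3, _ = self.encoderBlock3.forward(Y_encoder_2)
    del Y_encoder_2
    Y_encoder_4, Y_np4, _ = self.encoderBlock4.forward(Y_encoder_3)
    del Y_encoder_3

    Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)
    del Y_encoder_4

    # decoder path: each block combines the upsampled features with the matching skip map
    Y_decoder_1 = self.decoderBlock1.forward(Y_bottleNeck, Y_np4)
    del Y_bottleNeck, Y_np4
    Y_decoder_2 = self.decoderBlock2.forward(Y_decoder_1, Y_np3)
    del Y_decoder_1, Y_np3
    Y_decoder_3 = self.decoderBlock3.forward(Y_decoder_2, Y_np2)
    del Y_decoder_2, Y_np2
    Y_decoder_4 = self.decoderBlock4.forward(Y_decoder_3, Y_np1)
    del Y_decoder_3, Y_np1

    probability_map = self.classifier.forward(Y_decoder_4)
    del Y_decoder_4
    return probability_map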
@@ -393,7 +418,7 @@ class BrainMapperResUNet3Dshallow(nn.Module):
print("Initialized network parameters!")
class BrainMapperResUNet3D(nn.Module):
class BrainMapperResUNet3Dshallow(nn.Module):
"""Architecture class for Residual DenseBlock BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
@@ -420,7 +445,7 @@ class BrainMapperResUNet3D(nn.Module):
"""
def __init__(self, parameters):
super(BrainMapperResUNet3D, self).__init__()
super(BrainMapperResUNet3Dshallow, self).__init__()
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
@@ -429,7 +454,6 @@ class BrainMapperResUNet3D(nn.Module):
parameters['input_channels'] = parameters['output_channels']
self.encoderBlock2 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock3 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock4 = modules.DensEncoderBlock3D(parameters)
self.bottleneck = modules.DensBlock3D(parameters)
@@ -437,7 +461,6 @@ class BrainMapperResUNet3D(nn.Module):
self.decoderBlock1 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock2 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock3 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock4 = modules.DensDecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.DensClassifierBlock3D(parameters)
@@ -469,38 +492,28 @@ class BrainMapperResUNet3D(nn.Module):
del Y_encoder_2
Y_encoder_4, Y_np4, _ = self.encoderBlock4.forward(
Y_encoder_3)
Y_bottleNeck = self.bottleneck.forward(Y_encoder_3)
del Y_encoder_3
Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)
del Y_encoder_4
Y_decoder_1 = self.decoderBlock1.forward(
Y_bottleNeck, Y_np4)
Y_bottleNeck, Y_np3)
del Y_bottleNeck, Y_np4
del Y_bottleNeck, Y_np3
Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np3)
Y_decoder_1, Y_np2)
del Y_decoder_1, Y_np3
del Y_decoder_1, Y_np2
Y_decoder_3 = self.decoderBlock3.forward(
Y_decoder_2, Y_np2)
del Y_decoder_2, Y_np2
Y_decoder_4 = self.decoderBlock4.forward(
Y_decoder_3, Y_np1)
Y_decoder_2, Y_np1)
del Y_decoder_3, Y_np1
del Y_decoder_2, Y_np1
probability_map = self.classifier.forward(Y_decoder_4)
probability_map = self.classifier.forward(Y_decoder_3)
del Y_decoder_4
del Y_decoder_3
return probability_map
@@ -586,8 +599,8 @@ class BrainMapperResUNet3D(nn.Module):
print("Initialized network parameters!")
class BrainMapperUNet3D(nn.Module):
"""Architecture class for Traditional BrainMapper 3D U-net.
class BrainMapperResUNet3D(nn.Module):
"""Architecture class for Residual DenseBlock BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
@@ -613,41 +626,27 @@ class BrainMapperUNet3D(nn.Module):
"""
def __init__(self, parameters):
super(BrainMapperUNet3D, self).__init__()
super(BrainMapperResUNet3D, self).__init__()
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
self.encoderBlock1 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock2 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = modules.EncoderBlock3D(parameters)
self.encoderBlock1 = modules.DensEncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock4 = modules.EncoderBlock3D(parameters)
self.encoderBlock2 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock3 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock4 = modules.DensEncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.bottleneck = modules.ConvolutionalBlock3D(parameters)
self.bottleneck = modules.DensBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock1 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock2 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock3 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock4 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] * 2
self.decoderBlock1 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock2 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock3 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock4 = modules.DensDecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.ClassifierBlock3D(parameters)
self.classifier = modules.DensClassifierBlock3D(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
......
@@ -149,11 +149,10 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
BrainMapperModel = torch.load(
training_parameters['pre_trained_path'])
else:
# BrainMapperModel = BrainMapperUNet3D(network_parameters)
BrainMapperModel = BrainMapperUNet3D(network_parameters)
# BrainMapperModel = BrainMapperResUNet3D(network_parameters)
# BrainMapperModel = BrainMapperResUNet3Dshallow(network_parameters)
BrainMapperModel = BrainMapperCompResUNet3D(network_parameters)
# BrainMapperModel = BrainMapperCompResUNet3D(network_parameters)
BrainMapperModel.reset_parameters()
@@ -178,20 +177,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
use_last_checkpoint=training_parameters['use_last_checkpoint'],
experiment_directory=misc_parameters['experiments_directory'],
logs_directory=misc_parameters['logs_directory'],
checkpoint_directory=misc_parameters['checkpoint_directory']
checkpoint_directory=misc_parameters['checkpoint_directory'],
save_model_directory=misc_parameters['save_model_directory'],
final_model_output_file=training_parameters['final_model_output_file']
)
validation_loss = solver.train(train_loader, validation_loader)
model_output_path = os.path.join(
misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])
create_folder(misc_parameters['save_model_directory'])
BrainMapperModel.save(model_output_path)
print("Final Model Saved in: {}".format(model_output_path))
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer
torch.cuda.empty_cache()
@@ -283,7 +275,7 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
prediction_output_path=prediction_output_path,
device=misc_parameters['device'],
LogWriter=logWriter
)
)
logWriter.close()
@@ -298,7 +290,6 @@ def evaluate_mapping(mapping_evaluation_parameters):
mapping_evaluation_parameters = {
'trained_model_path': 'path/to/model'
'data_directory': 'path/to/data'
'mapping_data_file': 'path/to/file'
'data_list': 'path/to/datalist.txt'
'prediction_output_path': 'directory-of-saved-predictions'
'batch_size': 2
@@ -317,6 +308,7 @@ def evaluate_mapping(mapping_evaluation_parameters):
brain_mask_path = mapping_evaluation_parameters['brain_mask_path']
mean_mask_path = mapping_evaluation_parameters['mean_mask_path']
mean_reduction = mapping_evaluation_parameters['mean_reduction']
scaling_factors = mapping_evaluation_parameters['scaling_factors']
evaluations.evaluate_mapping(trained_model_path,
data_directory,
@@ -326,6 +318,7 @@ def evaluate_mapping(mapping_evaluation_parameters):
brain_mask_path,
mean_mask_path,
mean_reduction,
scaling_factors,
device=device,
exit_on_error=exit_on_error)
@@ -369,12 +362,19 @@ if __name__ == '__main__':
# Here we shuffle the data!
if data_parameters['data_split_flag'] == True:
print('Data is shuffling... This could take a few minutes!')
if data_parameters['data_split_flag'] == True:
if data_parameters['use_data_file'] == True:
data_test_train_validation_split(data_parameters['data_folder_name'],
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
mean_mask_path=data_parameters['mean_mask_path'],
data_file=data_parameters['data_file'],
K_fold=data_parameters['k_fold']
)
@@ -383,10 +383,15 @@ if __name__ == '__main__':
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
mean_mask_path=data_parameters['mean_mask_path'],
K_fold=data_parameters['k_fold']
)
update_shuffling_flag('settings.ini')
print('Data is shuffling... Complete!')
if arguments.mode == 'train':
train(data_parameters, training_parameters,
network_parameters, misc_parameters)
......
@@ -10,6 +10,7 @@ subject_number = None
train_list = "datasets/train.txt"
validation_list = "datasets/validation.txt"
test_list = "datasets/test.txt"
scaling_factors = "datasets/scaling_factors.pkl"
train_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
train_output_targets = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
validation_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
......
@@ -6,6 +6,7 @@ data_list = "datasets/test.txt"
prediction_output_path = "network_predictions"
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
mean_mask_path = "utils/mean_dr_stage2.nii.gz"
scaling_factors = "datasets/scaling_factors.pkl"
mean_reduction = True
device = 0
exit_on_error = True
\ No newline at end of file
exit_on_error = True
@@ -72,7 +72,9 @@ class Solver():
use_last_checkpoint=True,
experiment_directory='experiments',
logs_directory='logs',
checkpoint_directory='checkpoints'
checkpoint_directory='checkpoints',
save_model_directory='saved_models',
final_model_output_file='finetuned_alldata.pth.tar'
):
self.model = model
@@ -125,6 +127,9 @@ class Solver():
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
self.save_model_directory = save_model_directory
self.final_model_output_file = final_model_output_file
def train(self, train_loader, validation_loader):
"""Training Function
@@ -145,6 +150,8 @@ class Solver():
torch.cuda.empty_cache() # clear memory
model.cuda(self.device) # Moving the model to GPU
previous_checkpoint = None
print('****************************************************************')
print('TRAINING IS STARTING!')
print('=====================')
@@ -231,6 +238,8 @@ class Solver():
self.early_stop = early_stop
if save_checkpoint == True:
validation_loss = np.mean(losses)
checkpoint_name = os.path.join(
self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
@@ -238,12 +247,14 @@ class Solver():
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
filename=checkpoint_name
)
# if epoch != self.start_epoch:
# os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
# 'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
if previous_checkpoint != None:
os.remove(previous_checkpoint)
previous_checkpoint = checkpoint_name
else:
previous_checkpoint = checkpoint_name
if phase == 'train':
learning_rate_scheduler.step()
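Note: the previous_checkpoint bookkeeping added above is the "delete old checkpoints" part of this commit: after each epoch's checkpoint is written, the previous epoch's file is removed so only the most recent one stays on disk. Since both branches end by remembering the new file name, an equivalent, slightly shorter form would be:

# equivalent rotation logic (sketch): drop last epoch's checkpoint, then remember the new one
if previous_checkpoint is not None:
    os.remove(previous_checkpoint)
previous_checkpoint = checkpoint_name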
@@ -254,10 +265,18 @@ class Solver():
if self.early_stop == True:
print("ATTENTION!: Training stopped early to prevent overfitting!")
self.load_checkpoint()
break
else:
continue
model_output_path = os.path.join(
self.save_model_directory, self.final_model_output_file)
create_folder(self.save_model_directory)
model.save(model_output_path)
self.LogWriter.close()
print('----------------------------------------')
@@ -266,6 +285,7 @@ class Solver():
end_time = datetime.now()
print('Completed At: {}'.format(end_time))
print('Training Duration: {}'.format(end_time - start_time))
print('Final Model Saved in: {}'.format(model_output_path))
print('****************************************************************')
return validation_loss
......
@@ -15,6 +15,7 @@ TODO: Might be worth adding some information on uncertainty estimation, later d
"""
import os
import pickle
import numpy as np
import torch
import logging
@@ -209,6 +210,7 @@ def evaluate_mapping(trained_model_path,
brain_mask_path,
mean_mask_path,
mean_reduction,
scaling_factors,
device=0,
mode='evaluate',
exit_on_error=False):
@@ -225,6 +227,7 @@ def evaluate_mapping(trained_model_path,
brain_mask_path (str): Path to the MNI brain mask file
mean_mask_path (str): Path to the dualreg subject mean mask
mean_reduction (bool): Flag indicating if the targets should be de-meaned using the mean_mask_path
scaling_factors (str): Path to the scaling factor file
device (str/int): Device type used for training (int - GPU id, str- CPU)
mode (str): Current run mode or phase
exit_on_error (bool): Flag that triggers the raising of an exception
@@ -276,7 +279,7 @@ def evaluate_mapping(trained_model_path,
print("Mapping Volume {}/{}".format(volume_index+1, len(file_paths)))
# Generate volume & header
predicted_complete_volume, predicted_volume, header, xform = _generate_volume_map(
file_path, model, device, cuda_available, brain_mask_path, mean_mask_path, mean_reduction)
file_path, model, device, cuda_available, brain_mask_path, mean_mask_path, scaling_factors, mean_reduction)
# Generate New Header Affine
@@ -298,12 +301,14 @@ def evaluate_mapping(trained_model_path,
output_complete_nifti_image = Image(
predicted_complete_volume, header=header, xform=xform)
output_complete_nifti_path = output_nifti_path + '_complete'
output_complete_nifti_path = os.path.join(
prediction_output_path, volumes_to_be_used[volume_index]) + '_complete'
if '.nii' not in output_complete_nifti_path:
output_complete_nifti_path += '.nii.gz'
output_complete_nifti_image.save(output_complete_nifti_path)
output_complete_nifti_image.save(
output_complete_nifti_path)
log.info("Processed: " + volumes_to_be_used[volume_index] + " " + str(
volume_index + 1) + " out of " + str(len(volumes_to_be_used)))
@@ -323,7 +328,7 @@ def evaluate_mapping(trained_model_path,
log.info("rsfMRI Generation Complete")
def _generate_volume_map(file_path, model, device, cuda_available, brain_mask_path, mean_mask_path, mean_reduction=False):
def _generate_volume_map(file_path, model, device, cuda_available, brain_mask_path, mean_mask_path, scaling_factors, mean_reduction=False):
"""rsfMRI Volume Generator
This function uses the trained model to generate a new volume
@@ -335,6 +340,7 @@ def _generate_volume_map(file_path, model, device, cuda_available, brain_mask_pa
cuda_available (bool): Flag indicating if a cuda-enabled GPU is present
brain_mask_path (str): Path to the MNI brain mask file
mean_mask_path (str): Path to the dualreg subject mean mask
scaling_factors (str): Path to the scaling factor file
mean_reduction (bool): Flag indicating if the targets should be de-meaned using the mean_mask_path
Returns
@@ -345,41 +351,87 @@
volume, header, xform = data_utils.load_and_preprocess_evaluation(
file_path)
if len(volume.shape) == 4:
if len(volume.shape) == 5:
volume = volume
else:
volume = volume[np.newaxis, np.newaxis, :, :, :]
volume = _scale_input(volume, scaling_factors)
volume = torch.tensor(volume).type(torch.FloatTensor)
MNI152_T1_2mm_brain_mask = torch.from_numpy(Image(brain_mask_path).data)
if mean_reduction == True:
mean_mask = torch.from_numpy(Image(mean_mask_path).data[:, :, :, 0])
if cuda_available and (type(device) == int):
volume = volume.cuda(device)
MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(device)
if mean_reduction == True:
mean_mask = mean_mask.cuda(device)
output = model(volume)
output = torch.mul(output, MNI152_T1_2mm_brain_mask)
output = (output.cpu().numpy()).astype('float32')
output = np.squeeze(output)
output = _rescale_output(output, scaling_factors)
if mean_reduction==True:
predicted_complete_volume = torch.add(output, mean_mask)
predicted_complete_volume = (predicted_complete_volume.cpu().numpy()).astype('float32')
predicted_complete_volume = np.squeeze(predicted_complete_volume)
MNI152_T1_2mm_brain_mask = Image(brain_mask_path).data
if mean_reduction == True:
mean_mask = Image(mean_mask_path).data[:, :, :, 0]
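Note: _scale_input and _rescale_output, called in the hunk above, are not shown in this diff; the settings only add scaling_factors = "datasets/scaling_factors.pkl". A minimal sketch of one plausible implementation, assuming (hypothetically) that the pickle stores linear min/max factors for the inputs and the targets; the actual layout of the file may differ:

import pickle


def _scale_input(volume, scaling_factors):
    # hypothetical sketch: map the input volume into [0, 1] using stored min/max values
    with open(scaling_factors, 'rb') as handle:
        input_min, input_max, _, _ = pickle.load(handle)  # assumed file layout
    return (volume - input_min) / (input_max - input_min)


def _rescale_output(output, scaling_factors):
    # hypothetical sketch: undo the target scaling on the network output
    with open(scaling_factors, 'rb') as handle:
        _, _, target_min, target_max = pickle.load(handle)  # assumed file layout
    return output * (target_max - target_min) + target_min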