Commit 92f1b83d authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

Merge branch 'autoencoder' into 'master'

Autoencoder

See merge request !2
parents bd5515dc 159001d8
...@@ -120,4 +120,7 @@ datasets/ ...@@ -120,4 +120,7 @@ datasets/
files.txt files.txt
jobscript.sge.sh jobscript.sge.sh
*.nii.gz *.nii.gz
stuff/ stuff/
\ No newline at end of file test/*
.DS_Store
logs/
...@@ -18,6 +18,7 @@ import numpy as np ...@@ -18,6 +18,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import utils.modules as modules import utils.modules as modules
from torch.nn.init import _calculate_fan_in_and_fan_out as calculate_fan
class BrainMapperAE3D(nn.Module): class BrainMapperAE3D(nn.Module):
...@@ -80,6 +81,10 @@ class BrainMapperAE3D(nn.Module): ...@@ -80,6 +81,10 @@ class BrainMapperAE3D(nn.Module):
self.transformerBlock4 = modules.ResNetBlock3D(parameters) self.transformerBlock4 = modules.ResNetBlock3D(parameters)
self.transformerBlock5 = modules.ResNetBlock3D(parameters) self.transformerBlock5 = modules.ResNetBlock3D(parameters)
self.transformerBlock6 = modules.ResNetBlock3D(parameters) self.transformerBlock6 = modules.ResNetBlock3D(parameters)
self.transformerBlock7 = modules.ResNetBlock3D(parameters)
self.transformerBlock8 = modules.ResNetBlock3D(parameters)
self.transformerBlock9 = modules.ResNetBlock3D(parameters)
self.transformerBlock10 = modules.ResNetBlock3D(parameters)
# Decoder # Decoder
...@@ -125,6 +130,10 @@ class BrainMapperAE3D(nn.Module): ...@@ -125,6 +130,10 @@ class BrainMapperAE3D(nn.Module):
X = self.transformerBlock4.forward(X) X = self.transformerBlock4.forward(X)
X = self.transformerBlock5.forward(X) X = self.transformerBlock5.forward(X)
X = self.transformerBlock6.forward(X) X = self.transformerBlock6.forward(X)
X = self.transformerBlock7.forward(X)
X = self.transformerBlock8.forward(X)
X = self.transformerBlock9.forward(X)
X = self.transformerBlock10.forward(X)
# Decoder # Decoder
...@@ -214,5 +223,10 @@ class BrainMapperAE3D(nn.Module): ...@@ -214,5 +223,10 @@ class BrainMapperAE3D(nn.Module):
for _, subsubmodule in submodule.named_children(): for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False: if isinstance(subsubmodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False:
subsubmodule.reset_parameters() subsubmodule.reset_parameters()
# if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
# gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
# fan, _ = calculate_fan(subsubmodule.weight)
# std = np.divide(gain, np.sqrt(fan))
# subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!") print("Initialized network parameters!")
\ No newline at end of file
This diff is collapsed.
# FunctionMapper # BrainMapper
This project will aim to address one of the big challenges in imaging-neuroscience: that of how a brain’s functional connectivity, represented by resting-state maps, can be predicted from structural connectivity information obtained from dw-MRI. This project will aim to address one of the big challenges in imaging-neuroscience: that of how a brain’s functional connectivity, represented by resting-state maps, can be predicted from structural connectivity information obtained from dw-MRI.
......
...@@ -40,9 +40,8 @@ import torch.utils.data as data ...@@ -40,9 +40,8 @@ import torch.utils.data as data
import numpy as np import numpy as np
from solver import Solver from solver import Solver
# from BrainMapperUNet import BrainMapperUNet3D, BrainMapperResUNet3D, BrainMapperResUNet3Dshallow, BrainMapperCompResUNet3D
from BrainMapperAE import BrainMapperAE3D from BrainMapperAE import BrainMapperAE3D
from utils.data_utils import get_datasets, data_test_train_validation_split, update_shuffling_flag, create_folder from utils.data_utils import get_datasets, data_preparation, update_shuffling_flag, create_folder
import utils.data_evaluation_utils as evaluations import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter from utils.data_logging_utils import LogWriter
...@@ -150,16 +149,12 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -150,16 +149,12 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
BrainMapperModel = torch.load( BrainMapperModel = torch.load(
training_parameters['pre_trained_path']) training_parameters['pre_trained_path'])
else: else:
# BrainMapperModel = BrainMapperUNet3D(network_parameters)
# BrainMapperModel = BrainMapperResUNet3D(network_parameters)
# BrainMapperModel = BrainMapperResUNet3Dshallow(network_parameters)
# BrainMapperModel = BrainMapperCompResUNet3D(network_parameters)
BrainMapperModel = BrainMapperAE3D(network_parameters) BrainMapperModel = BrainMapperAE3D(network_parameters)
BrainMapperModel.reset_parameters() BrainMapperModel.reset_parameters()
optimizer = torch.optim.Adam optimizer = torch.optim.Adam
# optimizer = torch.optim.AdamW
solver = Solver(model=BrainMapperModel, solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'], device=misc_parameters['device'],
...@@ -259,28 +254,27 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva ...@@ -259,28 +254,27 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
# TODO - NEED TO UPDATE THE DATA FUNCTIONS! # TODO - NEED TO UPDATE THE DATA FUNCTIONS!
logWriter = LogWriter(number_of_classes=network_parameters['number_of_classes'],
logs_directory=misc_parameters['logs_directory'],
experiment_name=training_parameters['experiment_name']
)
prediction_output_path = os.path.join(misc_parameters['experiments_directory'], prediction_output_path = os.path.join(misc_parameters['experiments_directory'],
training_parameters['experiment_name'], training_parameters['experiment_name'],
evaluation_parameters['saved_predictions_directory'] evaluation_parameters['saved_predictions_directory']
) )
_ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'], evaluations.evaluate_correlation(trained_model_path=evaluation_parameters['trained_model_path'],
number_of_classes=network_parameters['number_of_classes'], data_directory=evaluation_parameters['data_directory'],
data_directory=evaluation_parameters['data_directory'], mapping_data_file=mapping_evaluation_parameters['mapping_data_file'],
targets_directory=evaluation_parameters['targets_directory'], target_data_file=evaluation_parameters['targets_directory'],
data_list=evaluation_parameters['data_list'], data_list=evaluation_parameters['data_list'],
orientation=evaluation_parameters['orientation'], prediction_output_path=prediction_output_path,
prediction_output_path=prediction_output_path, brain_mask_path=mapping_evaluation_parameters['brain_mask_path'],
device=misc_parameters['device'], rsfmri_mean_mask_path=mapping_evaluation_parameters[
LogWriter=logWriter 'rsfmri_mean_mask_path'],
) dmri_mean_mask_path=mapping_evaluation_parameters[
'dmri_mean_mask_path'],
logWriter.close() mean_regression=mapping_evaluation_parameters['mean_regression'],
scaling_factors=mapping_evaluation_parameters['scaling_factors'],
regression_factors=mapping_evaluation_parameters['regression_factors'],
device=misc_parameters['device'],
)
def evaluate_mapping(mapping_evaluation_parameters): def evaluate_mapping(mapping_evaluation_parameters):
...@@ -306,12 +300,15 @@ def evaluate_mapping(mapping_evaluation_parameters): ...@@ -306,12 +300,15 @@ def evaluate_mapping(mapping_evaluation_parameters):
mapping_data_file = mapping_evaluation_parameters['mapping_data_file'] mapping_data_file = mapping_evaluation_parameters['mapping_data_file']
data_list = mapping_evaluation_parameters['data_list'] data_list = mapping_evaluation_parameters['data_list']
prediction_output_path = mapping_evaluation_parameters['prediction_output_path'] prediction_output_path = mapping_evaluation_parameters['prediction_output_path']
dmri_mean_mask_path = mapping_evaluation_parameters['dmri_mean_mask_path']
rsfmri_mean_mask_path = mapping_evaluation_parameters['rsfmri_mean_mask_path']
device = mapping_evaluation_parameters['device'] device = mapping_evaluation_parameters['device']
exit_on_error = mapping_evaluation_parameters['exit_on_error'] exit_on_error = mapping_evaluation_parameters['exit_on_error']
brain_mask_path = mapping_evaluation_parameters['brain_mask_path'] brain_mask_path = mapping_evaluation_parameters['brain_mask_path']
mean_mask_path = mapping_evaluation_parameters['mean_mask_path'] mean_regression = mapping_evaluation_parameters['mean_regression']
mean_reduction = mapping_evaluation_parameters['mean_reduction'] mean_subtraction = mapping_evaluation_parameters['mean_subtraction']
scaling_factors = mapping_evaluation_parameters['scaling_factors'] scaling_factors = mapping_evaluation_parameters['scaling_factors']
regression_factors = mapping_evaluation_parameters['regression_factors']
evaluations.evaluate_mapping(trained_model_path, evaluations.evaluate_mapping(trained_model_path,
data_directory, data_directory,
...@@ -319,9 +316,12 @@ def evaluate_mapping(mapping_evaluation_parameters): ...@@ -319,9 +316,12 @@ def evaluate_mapping(mapping_evaluation_parameters):
data_list, data_list,
prediction_output_path, prediction_output_path,
brain_mask_path, brain_mask_path,
mean_mask_path, dmri_mean_mask_path,
mean_reduction, rsfmri_mean_mask_path,
mean_regression,
mean_subtraction,
scaling_factors, scaling_factors,
regression_factors,
device=device, device=device,
exit_on_error=exit_on_error) exit_on_error=exit_on_error)
...@@ -366,31 +366,31 @@ if __name__ == '__main__': ...@@ -366,31 +366,31 @@ if __name__ == '__main__':
# Here we shuffle the data! # Here we shuffle the data!
if data_parameters['data_split_flag'] == True: if data_parameters['data_split_flag'] == True:
print('Data is shuffling... This could take a few minutes!') print('Data is shuffling... This could take a few minutes!')
if data_parameters['data_split_flag'] == True:
if data_parameters['use_data_file'] == True: if data_parameters['use_data_file'] == True:
data_test_train_validation_split(data_parameters['data_folder_name'], data_preparation(data_parameters['data_folder_name'],
data_parameters['test_percentage'], data_parameters['test_percentage'],
data_parameters['subject_number'], data_parameters['subject_number'],
data_directory=data_parameters['data_directory'], data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'], train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'], train_targets=data_parameters['train_output_targets'],
mean_mask_path=data_parameters['mean_mask_path'], rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
data_file=data_parameters['data_file'], dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
K_fold=data_parameters['k_fold'] data_file=data_parameters['data_file'],
) K_fold=data_parameters['k_fold']
)
else: else:
data_test_train_validation_split(data_parameters['data_folder_name'], data_preparation(data_parameters['data_folder_name'],
data_parameters['test_percentage'], data_parameters['test_percentage'],
data_parameters['subject_number'], data_parameters['subject_number'],
data_directory=data_parameters['data_directory'], data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'], train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'], train_targets=data_parameters['train_output_targets'],
mean_mask_path=data_parameters['mean_mask_path'], rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
K_fold=data_parameters['k_fold'] dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
) K_fold=data_parameters['k_fold']
)
update_shuffling_flag('settings.ini') update_shuffling_flag('settings.ini')
print('Data is shuffling... Complete!') print('Data is shuffling... Complete!')
...@@ -400,7 +400,6 @@ if __name__ == '__main__': ...@@ -400,7 +400,6 @@ if __name__ == '__main__':
network_parameters, misc_parameters) network_parameters, misc_parameters)
# NOTE: THE EVAL FUNCTIONS HAVE NOT YET BEEN DEBUGGED (16/04/20) # NOTE: THE EVAL FUNCTIONS HAVE NOT YET BEEN DEBUGGED (16/04/20)
# NOTE: THE EVAL-MAPPING FUNCTION HAS BEEN DEBUGGED (28/04/20)
elif arguments.mode == 'evaluate-score': elif arguments.mode == 'evaluate-score':
evaluate_score(training_parameters, evaluate_score(training_parameters,
...@@ -433,6 +432,9 @@ if __name__ == '__main__': ...@@ -433,6 +432,9 @@ if __name__ == '__main__':
network_parameters, misc_parameters) network_parameters, misc_parameters)
logging.basicConfig(filename='evaluate-mapping-error.log') logging.basicConfig(filename='evaluate-mapping-error.log')
evaluate_mapping(mapping_evaluation_parameters) evaluate_mapping(mapping_evaluation_parameters)
elif arguments.mode == 'prepare-data':
print('Ensure you have updated the settings.ini file accordingly! This call does nothing but pass after data was shuffled!')
pass
else: else:
raise ValueError( raise ValueError(
'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, clear-experiments and clear-everything') 'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, prepare-data, clear-experiments and clear-everything')
...@@ -6,31 +6,34 @@ data_file = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/subj_22k. ...@@ -6,31 +6,34 @@ data_file = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/subj_22k.
k_fold = None k_fold = None
data_split_flag = False data_split_flag = False
test_percentage = 5 test_percentage = 5
subject_number = 600 subject_number = 12000
train_list = "datasets/train.txt" train_list = "datasets/train.txt"
validation_list = "datasets/validation.txt" validation_list = "datasets/validation.txt"
test_list = "datasets/test.txt" test_list = "datasets/test.txt"
scaling_factors = "datasets/scaling_factors.pkl" scaling_factors = "datasets/scaling_factors.pkl"
regression_weights = "datasets/regression_weights.pkl"
train_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz" train_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
train_output_targets = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz" train_output_targets = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
validation_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz" validation_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
validation_target_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz" validation_target_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz" brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
mean_mask_path = "utils/mean_dr_stage2.nii.gz" rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
mean_reduction = True dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
mean_regression = False
mean_subtraction = True
[TRAINING] [TRAINING]
experiment_name = "CU3D17-3" experiment_name = "VA2-1"
pre_trained_path = "saved_models/CU3D17-3.pth.tar" pre_trained_path = "saved_models/VA2-1.pth.tar"
final_model_output_file = "CU3D17-3.pth.tar" final_model_output_file = "VA2-1.pth.tar"
training_batch_size = 3 training_batch_size = 5
validation_batch_size = 3 validation_batch_size = 5
use_pre_trained = False use_pre_trained = False
learning_rate = 1e-1 learning_rate = 1e-5
optimizer_beta = (0.9, 0.999) optimizer_beta = (0.9, 0.999)
optimizer_epsilon = 1e-8 optimizer_epsilon = 1e-8
optimizer_weigth_decay = 1e-5 optimizer_weigth_decay = 1e-5
number_of_epochs = 200 number_of_epochs = 10
loss_log_period = 50 loss_log_period = 50
learning_rate_scheduler_step_size = 5 learning_rate_scheduler_step_size = 5
learning_rate_scheduler_gamma = 1e-1 learning_rate_scheduler_gamma = 1e-1
...@@ -40,19 +43,15 @@ use_last_checkpoint = False ...@@ -40,19 +43,15 @@ use_last_checkpoint = False
kernel_heigth = 3 kernel_heigth = 3
kernel_width = 3 kernel_width = 3
kernel_depth = 3 kernel_depth = 3
; kernel_classification = 1 kernel_classification = 7
input_channels = 1 input_channels = 1
output_channels = 64 output_channels = 64
convolution_stride = 1 convolution_stride = 1
dropout = 0.2 dropout = 0
; pool_kernel_size = 2 pool_kernel_size = 3
pool_stride = 2 pool_stride = 2
up_mode = "upconv" up_mode = "upconv"
number_of_classes = 1 number_of_classes = 1
; ---> parameters for the ResNet CGAN
pool_kernel_size = 3
kernel_classification = 7
[MISC] [MISC]
save_model_directory = "saved_models" save_model_directory = "saved_models"
......
[MAPPING] [MAPPING]
trained_model_path = "saved_models/test24.pth.tar" trained_model_path = "saved_models/VA2.pth.tar"
prediction_output_path = "test24_predictions" prediction_output_path = "VA2_predictions"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/" data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz" mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
data_list = "datasets/test.txt" data_list = "datasets/test.txt"
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz" brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
mean_mask_path = "utils/mean_dr_stage2.nii.gz" mean_mask_path = "utils/mean_dr_stage2.nii.gz"
scaling_factors = "datasets/scaling_factors.pkl" scaling_factors = "datasets/scaling_factors.pkl"
mean_reduction = True regression_factors = "datasets/regression_weights.pkl"
mean_regression = False
mean_subtraction = True
device = 0 device = 0
exit_on_error = True exit_on_error = True
...@@ -10,11 +10,9 @@ setup( ...@@ -10,11 +10,9 @@ setup(
install_requires=[ install_requires=[
'pip', 'pip',
'matplotlib', 'matplotlib',
'nibabel',
'numpy', 'numpy',
'pandas', 'pandas',
'torch==1.4', 'torch==1.4',
'h5py',
'fslpy', 'fslpy',
'tensorboardX', 'tensorboardX',
'sklearn', 'sklearn',
......
...@@ -63,6 +63,8 @@ class Solver(): ...@@ -63,6 +63,8 @@ class Solver():
optimizer, optimizer,
optimizer_arguments={}, optimizer_arguments={},
loss_function=MSELoss(), loss_function=MSELoss(),
# loss_function=torch.nn.L1Loss(),
# loss_function=torch.nn.CosineEmbeddingLoss(),
model_name='BrainMapper', model_name='BrainMapper',
labels=None, labels=None,
number_epochs=10, number_epochs=10,
...@@ -83,8 +85,10 @@ class Solver(): ...@@ -83,8 +85,10 @@ class Solver():
if torch.cuda.is_available(): if torch.cuda.is_available():
self.loss_function = loss_function.cuda(device) self.loss_function = loss_function.cuda(device)
self.MSE = MSELoss().cuda(device)
else: else:
self.loss_function = loss_function self.loss_function = loss_function
self.MSE = MSELoss()
self.model_name = model_name self.model_name = model_name
self.labels = labels self.labels = labels
...@@ -152,6 +156,7 @@ class Solver(): ...@@ -152,6 +156,7 @@ class Solver():
previous_checkpoint = None previous_checkpoint = None
previous_loss = None previous_loss = None
previous_MSE = None
print('****************************************************************') print('****************************************************************')
print('TRAINING IS STARTING!') print('TRAINING IS STARTING!')
...@@ -175,6 +180,7 @@ class Solver(): ...@@ -175,6 +180,7 @@ class Solver():
print('-> Phase: {}'.format(phase)) print('-> Phase: {}'.format(phase))
losses = [] losses = []
MSEs = []
if phase == 'train': if phase == 'train':
model.train() model.train()
...@@ -189,10 +195,8 @@ class Solver(): ...@@ -189,10 +195,8 @@ class Solver():
X = torch.unsqueeze(X, dim=1) X = torch.unsqueeze(X, dim=1)
y = torch.unsqueeze(y, dim=1) y = torch.unsqueeze(y, dim=1)
print('X range:', torch.min(X), torch.max(X)) MNI152_T1_2mm_brain_mask = torch.unsqueeze(
print('y range:', torch.min(y), torch.max(y)) torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
MNI152_T1_2mm_brain_mask = torch.unsqueeze(torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
if model.test_if_cuda: if model.test_if_cuda:
X = X.cuda(self.device, non_blocking=True) X = X.cuda(self.device, non_blocking=True)
...@@ -202,13 +206,14 @@ class Solver(): ...@@ -202,13 +206,14 @@ class Solver():
y_hat = model(X) # Forward pass & Masking y_hat = model(X) # Forward pass & Masking
print('y_hat range:', torch.min(y_hat), torch.max(y_hat))
y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask) y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)
print('y_hat masked range:', torch.min(y_hat), torch.max(y_hat))
loss = self.loss_function(y_hat, y) # Loss computation loss = self.loss_function(y_hat, y) # Loss computation
# loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))
# We also calculate a separate MSE for cost function comparison!
MSE = self.MSE(y_hat, y)
MSEs.append(MSE.item())
if phase == 'train': if phase == 'train':
optimizer.zero_grad() # Zero the parameter gradients optimizer.zero_grad() # Zero the parameter gradients
...@@ -226,7 +231,7 @@ class Solver(): ...@@ -226,7 +231,7 @@ class Solver():
# Clear the memory # Clear the memory
del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
torch.cuda.empty_cache() torch.cuda.empty_cache()
if phase == 'validation': if phase == 'validation':
...@@ -240,9 +245,14 @@ class Solver(): ...@@ -240,9 +245,14 @@ class Solver():
if phase == 'train': if phase == 'train':
self.LogWriter.loss_per_epoch(losses, phase, epoch) self.LogWriter.loss_per_epoch(losses, phase, epoch)
self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
elif phase == 'validation': elif phase == 'validation':
self.LogWriter.loss_per_epoch(losses, phase, epoch, previous_loss=previous_loss) self.LogWriter.loss_per_epoch(
losses, phase, epoch, previous_loss=previous_loss)
previous_loss = np.mean(losses) previous_loss = np.mean(losses)
self.LogWriter.MSE_per_epoch(
MSEs, phase, epoch, previous_loss=previous_MSE)
previous_MSE = np.mean(MSEs)
if phase == 'validation': if phase == 'validation':
early_stop, save_checkpoint = self.EarlyStopping( early_stop, save_checkpoint = self.EarlyStopping(
...@@ -300,6 +310,9 @@ class Solver(): ...@@ -300,6 +310,9 @@ class Solver():
print('Final Model Saved in: {}'.format(model_output_path)) print('Final Model Saved in: {}'.format(model_output_path))
print('****************************************************************') print('****************************************************************')
if self.start_epoch >= self.number_epochs+1:
validation_loss = None
return validation_loss return validation_loss
def save_checkpoint(self, state, filename): def save_checkpoint(self, state, filename):
......
"""Biobank Data Stats Calculator
Description:
This file contains the relevant scripts for producing a database containing relevant statistics about the imaing data from the UK Biobank.
This is a standalone scrip, intended to be used only once during the project. Hence, it is not integrated into the larger utils packages.
Usage:
To use content from this folder, import the functions and instantiate them as you wish to use them:
from utils.DSbiobank import function_name
"""
import numpy as np
from fsl.data.image import Image
from fsl.utils.image.resample import resampleToPixdims
import matplotlib.pyplot as plt
from data_utils import directory_reader, regression_weight_calculator
from tempfile import TemporaryFile
from datetime import datetime
import pandas as pd
import os
def stats_calc(array):
""" Statistics calculator
Function calculating all the required statistics for every array
Args:
array (np.array): 3D array of subject data