Commit aa74b23e authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

Merge branch 'autoencoder_2' into 'master'

Autoencoder 2

See merge request !3
parents 92f1b83d ca61c821
...@@ -75,16 +75,7 @@ class BrainMapperAE3D(nn.Module): ...@@ -75,16 +75,7 @@ class BrainMapperAE3D(nn.Module):
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['convolution_stride'] = original_stride parameters['convolution_stride'] = original_stride
self.transformerBlock1 = modules.ResNetBlock3D(parameters) self.transformerBlocks = nn.ModuleList([modules.ResNetBlock3D(parameters) for i in range(parameters['number_of_transformer_blocks'])])
self.transformerBlock2 = modules.ResNetBlock3D(parameters)
self.transformerBlock3 = modules.ResNetBlock3D(parameters)
self.transformerBlock4 = modules.ResNetBlock3D(parameters)
self.transformerBlock5 = modules.ResNetBlock3D(parameters)
self.transformerBlock6 = modules.ResNetBlock3D(parameters)
self.transformerBlock7 = modules.ResNetBlock3D(parameters)
self.transformerBlock8 = modules.ResNetBlock3D(parameters)
self.transformerBlock9 = modules.ResNetBlock3D(parameters)
self.transformerBlock10 = modules.ResNetBlock3D(parameters)
# Decoder # Decoder
...@@ -124,16 +115,8 @@ class BrainMapperAE3D(nn.Module): ...@@ -124,16 +115,8 @@ class BrainMapperAE3D(nn.Module):
# Transformer # Transformer
X = self.transformerBlock1.forward(X) for transformerBlock in self.transformerBlocks:
X = self.transformerBlock2.forward(X) X = transformerBlock(X)
X = self.transformerBlock3.forward(X)
X = self.transformerBlock4.forward(X)
X = self.transformerBlock5.forward(X)
X = self.transformerBlock6.forward(X)
X = self.transformerBlock7.forward(X)
X = self.transformerBlock8.forward(X)
X = self.transformerBlock9.forward(X)
X = self.transformerBlock10.forward(X)
# Decoder # Decoder
...@@ -207,26 +190,39 @@ class BrainMapperAE3D(nn.Module): ...@@ -207,26 +190,39 @@ class BrainMapperAE3D(nn.Module):
return prediction return prediction
def reset_parameters(self): def reset_parameters(self, custom_weight_reset_flag):
"""Parameter Initialization """Parameter Initialization
This function (re)initializes the parameters of the defined network. This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module. This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859 More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639 An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
Args:
custom_weight_reset_flag (bool): Flag indicating if the modified weight initialisation approach should be used.
""" """
print("Initializing network parameters...") print("Initializing network parameters...")
for _, module in self.named_children(): for _, module in self.named_children():
for _, submodule in module.named_children(): for _, submodule in module.named_children():
if isinstance(submodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
submodule.reset_parameters()
if custom_weight_reset_flag == True:
if isinstance(submodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(submodule.weight)
std = np.divide(gain, np.sqrt(fan))
submodule.weight.data.normal_(0, std)
for _, subsubmodule in submodule.named_children(): for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False: if isinstance(subsubmodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
subsubmodule.reset_parameters() subsubmodule.reset_parameters()
# if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)): if custom_weight_reset_flag == True:
# gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2))) if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
# fan, _ = calculate_fan(subsubmodule.weight) gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
# std = np.divide(gain, np.sqrt(fan)) fan, _ = calculate_fan(subsubmodule.weight)
# subsubmodule.weight.data.normal_(0, std) std = np.divide(gain, np.sqrt(fan))
subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!") print("Initialized network parameters!")
...@@ -33,7 +33,6 @@ import os ...@@ -33,7 +33,6 @@ import os
import shutil import shutil
import argparse import argparse
import logging import logging
from settings import Settings
import torch import torch
import torch.utils.data as data import torch.utils.data as data
...@@ -41,9 +40,11 @@ import numpy as np ...@@ -41,9 +40,11 @@ import numpy as np
from solver import Solver from solver import Solver
from BrainMapperAE import BrainMapperAE3D from BrainMapperAE import BrainMapperAE3D
from utils.data_utils import get_datasets, data_preparation, update_shuffling_flag, create_folder from utils.data_utils import get_datasets
from utils.settings import Settings
import utils.data_evaluation_utils as evaluations import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter from utils.data_logging_utils import LogWriter
from utils.common_utils import create_folder
# Set the default floating point tensor type to FloatTensor # Set the default floating point tensor type to FloatTensor
...@@ -133,7 +134,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -133,7 +134,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
dataset=train_data, dataset=train_data,
batch_size=training_parameters['training_batch_size'], batch_size=training_parameters['training_batch_size'],
shuffle=True, shuffle=True,
num_workers=4,
pin_memory=True pin_memory=True
) )
...@@ -141,7 +141,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -141,7 +141,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
dataset=validation_data, dataset=validation_data,
batch_size=training_parameters['validation_batch_size'], batch_size=training_parameters['validation_batch_size'],
shuffle=False, shuffle=False,
num_workers=4,
pin_memory=True pin_memory=True
) )
...@@ -151,10 +150,14 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -151,10 +150,14 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
else: else:
BrainMapperModel = BrainMapperAE3D(network_parameters) BrainMapperModel = BrainMapperAE3D(network_parameters)
BrainMapperModel.reset_parameters() custom_weight_reset_flag = network_parameters['custom_weight_reset_flag']
optimizer = torch.optim.Adam BrainMapperModel.reset_parameters(custom_weight_reset_flag)
# optimizer = torch.optim.AdamW
if training_parameters['adam_w_flag'] == True:
optimizer = torch.optim.AdamW
else:
optimizer = torch.optim.Adam
solver = Solver(model=BrainMapperModel, solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'], device=misc_parameters['device'],
...@@ -177,7 +180,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -177,7 +180,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
logs_directory=misc_parameters['logs_directory'], logs_directory=misc_parameters['logs_directory'],
checkpoint_directory=misc_parameters['checkpoint_directory'], checkpoint_directory=misc_parameters['checkpoint_directory'],
save_model_directory=misc_parameters['save_model_directory'], save_model_directory=misc_parameters['save_model_directory'],
final_model_output_file=training_parameters['final_model_output_file'] final_model_output_file=training_parameters['final_model_output_file'],
crop_flag = data_parameters['crop_flag']
) )
validation_loss = solver.train(train_loader, validation_loader) validation_loss = solver.train(train_loader, validation_loader)
...@@ -187,94 +191,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -187,94 +191,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
return validation_loss return validation_loss
if data_parameters['k_fold'] is None:
_ = _train_runner(data_parameters, training_parameters,
network_parameters, misc_parameters)
else:
print("Training initiated using K-fold Cross Validation!")
k_fold_losses = []
for k in range(data_parameters['k_fold']):
print("K-fold Number: {}".format(k+1))
data_parameters['train_list'] = os.path.join(
data_parameters['data_folder_name'], 'train' + str(k+1)+'.txt')
data_parameters['validation_list'] = os.path.join(
data_parameters['data_folder_name'], 'validation' + str(k+1)+'.txt')
training_parameters['final_model_output_file'] = training_parameters['final_model_output_file'].replace(
".pth.tar", str(k+1)+".pth.tar")
validation_loss = _train_runner(
data_parameters, training_parameters, network_parameters, misc_parameters)
k_fold_losses.append(validation_loss)
for k in range(data_parameters['k_fold']):
print("K-fold Number: {} Loss: {}".format(k+1, k_fold_losses[k]))
print("K-fold Cross Validation Avearge Loss: {}".format(np.mean(k_fold_losses)))
_ = _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters)
def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
"""Mapping Score Evaluator
This function evaluates a given trained model by calculating the it's dice score prediction.
Args:
training_parameters(dict): Dictionary containing relevant hyperparameters for training the network.
training_parameters = {
'experiment_name': 'experiment_name'
}
network_parameters (dict): Contains information relevant parameters
network_parameters= {
'number_of_classes': 1
}
misc_parameters (dict): Dictionary of aditional hyperparameters
misc_parameters = {
'logs_directory': 'log-directory'
'device': 1
'experiments_directory': 'experiments-directory'
}
evaluation_parameters (dict): Dictionary of parameters useful during evaluation.
evaluation_parameters = {
'trained_model_path': 'path/to/model'
'data_directory': 'path/to/data'
'targets_directory': 'path/to/targets'
'data_list': 'path/to/datalist.txt/
'orientation': 'coronal'
'saved_predictions_directory': 'directory-of-saved-predictions'
}
"""
# TODO - NEED TO UPDATE THE DATA FUNCTIONS!
prediction_output_path = os.path.join(misc_parameters['experiments_directory'],
training_parameters['experiment_name'],
evaluation_parameters['saved_predictions_directory']
)
evaluations.evaluate_correlation(trained_model_path=evaluation_parameters['trained_model_path'],
data_directory=evaluation_parameters['data_directory'],
mapping_data_file=mapping_evaluation_parameters['mapping_data_file'],
target_data_file=evaluation_parameters['targets_directory'],
data_list=evaluation_parameters['data_list'],
prediction_output_path=prediction_output_path,
brain_mask_path=mapping_evaluation_parameters['brain_mask_path'],
rsfmri_mean_mask_path=mapping_evaluation_parameters[
'rsfmri_mean_mask_path'],
dmri_mean_mask_path=mapping_evaluation_parameters[
'dmri_mean_mask_path'],
mean_regression=mapping_evaluation_parameters['mean_regression'],
scaling_factors=mapping_evaluation_parameters['scaling_factors'],
regression_factors=mapping_evaluation_parameters['regression_factors'],
device=misc_parameters['device'],
)
def evaluate_mapping(mapping_evaluation_parameters): def evaluate_mapping(mapping_evaluation_parameters):
...@@ -305,10 +223,18 @@ def evaluate_mapping(mapping_evaluation_parameters): ...@@ -305,10 +223,18 @@ def evaluate_mapping(mapping_evaluation_parameters):
device = mapping_evaluation_parameters['device'] device = mapping_evaluation_parameters['device']
exit_on_error = mapping_evaluation_parameters['exit_on_error'] exit_on_error = mapping_evaluation_parameters['exit_on_error']
brain_mask_path = mapping_evaluation_parameters['brain_mask_path'] brain_mask_path = mapping_evaluation_parameters['brain_mask_path']
mean_regression = mapping_evaluation_parameters['mean_regression']
mean_subtraction = mapping_evaluation_parameters['mean_subtraction']
scaling_factors = mapping_evaluation_parameters['scaling_factors']
regression_factors = mapping_evaluation_parameters['regression_factors'] regression_factors = mapping_evaluation_parameters['regression_factors']
mean_regression_flag = mapping_evaluation_parameters['mean_regression_flag']
mean_regression_all_flag = mapping_evaluation_parameters['mean_regression_all_flag']
mean_subtraction_flag = mapping_evaluation_parameters['mean_subtraction_flag']
scale_volumes_flag = mapping_evaluation_parameters['scale_volumes_flag']
normalize_flag = mapping_evaluation_parameters['normalize_flag']
minus_one_scaling_flag = mapping_evaluation_parameters['minus_one_scaling_flag']
negative_flag = mapping_evaluation_parameters['negative_flag']
outlier_flag = mapping_evaluation_parameters['outlier_flag']
shrinkage_flag = mapping_evaluation_parameters['shrinkage_flag']
hard_shrinkage_flag = mapping_evaluation_parameters['hard_shrinkage_flag']
crop_flag = mapping_evaluation_parameters['crop_flag']
evaluations.evaluate_mapping(trained_model_path, evaluations.evaluate_mapping(trained_model_path,
data_directory, data_directory,
...@@ -318,13 +244,20 @@ def evaluate_mapping(mapping_evaluation_parameters): ...@@ -318,13 +244,20 @@ def evaluate_mapping(mapping_evaluation_parameters):
brain_mask_path, brain_mask_path,
dmri_mean_mask_path, dmri_mean_mask_path,
rsfmri_mean_mask_path, rsfmri_mean_mask_path,
mean_regression,
mean_subtraction,
scaling_factors,
regression_factors, regression_factors,
device=device, mean_regression_flag,
exit_on_error=exit_on_error) mean_regression_all_flag,
mean_subtraction_flag,
scale_volumes_flag,
normalize_flag,
minus_one_scaling_flag,
negative_flag,
outlier_flag,
shrinkage_flag,
hard_shrinkage_flag,
crop_flag,
device,
exit_on_error)
def delete_files(folder): def delete_files(folder):
""" Clear Folder Contents """ Clear Folder Contents
...@@ -351,49 +284,32 @@ if __name__ == '__main__': ...@@ -351,49 +284,32 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--mode', '-m', required=True, parser.add_argument('--mode', '-m', required=True,
help='run mode, valid values are train or evaluate') help='run mode, valid values are train or evaluate')
parser.add_argument('--settings_path', '-sp', required=False, parser.add_argument('--model_name', '-n', required=True,
help='optional argument, set path to settings_evaluation.ini') help='model name, required for identifying the settings file modelName.ini & modelName_eval.ini')
parser.add_argument('--use_last_checkpoint', '-c', required=False,
help='flag indicating if the last checkpoint should be used if 1; useful when wanting to time-limit jobs.')
parser.add_argument('--number_of_epochs', '-e', required=False,
help='flag indicating how many epochs the network will train for; should be limited to ~3 hours or 2/3 epochs')
arguments = parser.parse_args() arguments = parser.parse_args()
settings = Settings('settings.ini') settings_file_name = arguments.model_name + '.ini'
evaluation_settings_file_name = arguments.model_name + '_eval.ini'
settings = Settings(settings_file_name)
data_parameters = settings['DATA'] data_parameters = settings['DATA']
training_parameters = settings['TRAINING'] training_parameters = settings['TRAINING']
network_parameters = settings['NETWORK'] network_parameters = settings['NETWORK']
misc_parameters = settings['MISC'] misc_parameters = settings['MISC']
evaluation_parameters = settings['EVALUATION'] evaluation_parameters = settings['EVALUATION']
# Here we shuffle the data! if arguments.use_last_checkpoint == '1':
training_parameters['use_last_checkpoint'] = True
if data_parameters['data_split_flag'] == True: elif arguments.use_last_checkpoint == '0':
print('Data is shuffling... This could take a few minutes!') training_parameters['use_last_checkpoint'] = False
if data_parameters['use_data_file'] == True: if arguments.number_of_epochs is not None:
data_preparation(data_parameters['data_folder_name'], training_parameters['number_of_epochs'] = int(arguments.number_of_epochs)
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
data_file=data_parameters['data_file'],
K_fold=data_parameters['k_fold']
)
else:
data_preparation(data_parameters['data_folder_name'],
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
K_fold=data_parameters['k_fold']
)
update_shuffling_flag('settings.ini')
print('Data is shuffling... Complete!')
if arguments.mode == 'train': if arguments.mode == 'train':
train(data_parameters, training_parameters, train(data_parameters, training_parameters,
...@@ -401,15 +317,9 @@ if __name__ == '__main__': ...@@ -401,15 +317,9 @@ if __name__ == '__main__':
# NOTE: THE EVAL FUNCTIONS HAVE NOT YET BEEN DEBUGGED (16/04/20) # NOTE: THE EVAL FUNCTIONS HAVE NOT YET BEEN DEBUGGED (16/04/20)
elif arguments.mode == 'evaluate-score':
evaluate_score(training_parameters,
network_parameters, misc_parameters, evaluation_parameters)
elif arguments.mode == 'evaluate-mapping': elif arguments.mode == 'evaluate-mapping':
logging.basicConfig(filename='evaluate-mapping-error.log') logging.basicConfig(filename='evaluate-mapping-error.log')
if arguments.settings_path is not None: settings_evaluation = Settings(evaluation_settings_file_name)
settings_evaluation = Settings(arguments.settings_path)
else:
settings_evaluation = Settings('settings_evaluation.ini')
mapping_evaluation_parameters = settings_evaluation['MAPPING'] mapping_evaluation_parameters = settings_evaluation['MAPPING']
evaluate_mapping(mapping_evaluation_parameters) evaluate_mapping(mapping_evaluation_parameters)
elif arguments.mode == 'clear-experiments': elif arguments.mode == 'clear-experiments':
...@@ -423,18 +333,12 @@ if __name__ == '__main__': ...@@ -423,18 +333,12 @@ if __name__ == '__main__':
delete_files(misc_parameters['logs_directory']) delete_files(misc_parameters['logs_directory'])
print('Cleared the current experiments and logs directory successfully!') print('Cleared the current experiments and logs directory successfully!')
elif arguments.mode == 'train-and-evaluate-mapping': elif arguments.mode == 'train-and-evaluate-mapping':
if arguments.settings_path is not None: settings_evaluation = Settings(evaluation_settings_file_name)
settings_evaluation = Settings(arguments.settings_path)
else:
settings_evaluation = Settings('settings_evaluation.ini')
mapping_evaluation_parameters = settings_evaluation['MAPPING'] mapping_evaluation_parameters = settings_evaluation['MAPPING']
train(data_parameters, training_parameters, train(data_parameters, training_parameters,
network_parameters, misc_parameters) network_parameters, misc_parameters)
logging.basicConfig(filename='evaluate-mapping-error.log') logging.basicConfig(filename='evaluate-mapping-error.log')
evaluate_mapping(mapping_evaluation_parameters) evaluate_mapping(mapping_evaluation_parameters)
elif arguments.mode == 'prepare-data':
print('Ensure you have updated the settings.ini file accordingly! This call does nothing but pass after data was shuffled!')
pass
else: else:
raise ValueError( raise ValueError(
'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, prepare-data, clear-experiments and clear-everything') 'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, clear-experiments and clear-everything')
[DATA] [DATA]
data_folder_name = "datasets" data_folder_name = "datasets"
use_data_file = False input_data_train = "input_data_train.h5"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/" target_data_train = "target_data_train.h5"
data_file = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/subj_22k.txt" input_data_validation = "input_data_validation.h5"
k_fold = None target_data_validation = "target_data_validation.h5"
data_split_flag = False crop_flag = False
test_percentage = 5
subject_number = 12000
train_list = "datasets/train.txt"
validation_list = "datasets/validation.txt"
test_list = "datasets/test.txt"
scaling_factors = "datasets/scaling_factors.pkl"
regression_weights = "datasets/regression_weights.pkl"
train_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
train_output_targets = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
validation_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
validation_target_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
mean_regression = False
mean_subtraction = True
[TRAINING] [TRAINING]
experiment_name = "VA2-1" experiment_name = "VA2-1"
pre_trained_path = "saved_models/VA2-1.pth.tar" pre_trained_path = "saved_models/VA2-1.pth.tar"
final_model_output_file = "VA2-1.pth.tar" final_model_output_file = "VA2-1.pth.tar"
training_batch_size = 5 training_batch_size = 3
validation_batch_size = 5 validation_batch_size = 3
use_pre_trained = False use_pre_trained = False
learning_rate = 1e-5 learning_rate = 1e-5
optimizer_beta = (0.9, 0.999) optimizer_beta = (0.9, 0.999)
...@@ -35,9 +19,10 @@ optimizer_epsilon = 1e-8 ...@@ -35,9 +19,10 @@ optimizer_epsilon = 1e-8
optimizer_weigth_decay = 1e-5 optimizer_weigth_decay = 1e-5
number_of_epochs = 10 number_of_epochs = 10
loss_log_period = 50 loss_log_period = 50
learning_rate_scheduler_step_size = 5 learning_rate_scheduler_step_size = 6
learning_rate_scheduler_gamma = 1e-1 learning_rate_scheduler_gamma = 1e-1
use_last_checkpoint = False use_last_checkpoint = False
adam_w_flag = False
[NETWORK] [NETWORK]
kernel_heigth = 3 kernel_heigth = 3
...@@ -45,13 +30,16 @@ kernel_width = 3 ...@@ -45,13 +30,16 @@ kernel_width = 3
kernel_depth = 3 kernel_depth = 3
kernel_classification = 7 kernel_classification = 7
input_channels = 1 input_channels = 1
output_channels = 64 output_channels = 32
convolution_stride = 1 convolution_stride = 1
dropout = 0 dropout = 0
pool_kernel_size = 3 pool_kernel_size = 3
pool_stride = 2 pool_stride = 2
up_mode = "upconv" up_mode = "upconv"
final_activation = 'tanh'
number_of_classes = 1 number_of_classes = 1
number_of_transformer_blocks = 6
custom_weight_reset_flag = False
[MISC] [MISC]
save_model_directory = "saved_models" save_model_directory = "saved_models"
......
[MAPPING] [MAPPING]
trained_model_path = "saved_models/VA2.pth.tar" trained_model_path = "saved_models/VA2-1.pth.tar"
prediction_output_path = "VA2_predictions" prediction_output_path = "VA2-1_predictions"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/" data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz" mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
data_list = "datasets/test.txt" data_list = "datasets/test.txt"
...@@ -8,9 +8,17 @@ brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz" ...@@ -8,9 +8,17 @@ brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz" rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz" dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
mean_mask_path = "utils/mean_dr_stage2.nii.gz" mean_mask_path = "utils/mean_dr_stage2.nii.gz"
scaling_factors = "datasets/scaling_factors.pkl"
regression_factors = "datasets/regression_weights.pkl" regression_factors = "datasets/regression_weights.pkl"
mean_regression = False mean_regression_flag = True
mean_subtraction = True mean_regression_all_flag = True
mean_subtraction_flag = False
scale_volumes_flag = True
normalize_flag = False
minus_one_scaling_flag = True
negative_flag = False
outlier_flag = True
shrinkage_flag = False
hard_shrinkage_flag = False
crop_flag = True
device = 0 device = 0
exit_on_error = True exit_on_error = True
\ No newline at end of file
...@@ -16,5 +16,7 @@ setup( ...@@ -16,5 +16,7 @@ setup(
'fslpy', 'fslpy',
'tensorboardX', 'tensorboardX',
'sklearn', 'sklearn',
'nibabel',
'h5py',
],