Commit 52785b10 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

bug fixes

parent 39935616
...@@ -159,7 +159,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -159,7 +159,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
device=misc_parameters['device'], device=misc_parameters['device'],
number_of_classes=network_parameters['number_of_classes'], number_of_classes=network_parameters['number_of_classes'],
experiment_name=training_parameters['experiment_name'], experiment_name=training_parameters['experiment_name'],
optimizer= optimizer, optimizer=optimizer,
optimizer_arguments={'lr': training_parameters['learning_rate'], optimizer_arguments={'lr': training_parameters['learning_rate'],
'betas': training_parameters['optimizer_beta'], 'betas': training_parameters['optimizer_beta'],
'eps': training_parameters['optimizer_epsilon'], 'eps': training_parameters['optimizer_epsilon'],
...@@ -204,7 +204,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -204,7 +204,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
for k in range(data_parameters['k_fold']): for k in range(data_parameters['k_fold']):
print("K-fold Number: {}".format(k+1)) print("K-fold Number: {}".format(k+1))
data_parameters['train_list'] = os.path.join( data_parameters['train_list'] = os.path.join(
data_parameters['data_folder_name'], 'train' + str(k+1)+'.txt') data_parameters['data_folder_name'], 'train' + str(k+1)+'.txt')
...@@ -314,12 +314,14 @@ def evaluate_mapping(mapping_evaluation_parameters): ...@@ -314,12 +314,14 @@ def evaluate_mapping(mapping_evaluation_parameters):
exit_on_error = mapping_evaluation_parameters['exit_on_error'] exit_on_error = mapping_evaluation_parameters['exit_on_error']
evaluations.evaluate_mapping(trained_model_path, evaluations.evaluate_mapping(trained_model_path,
data_directory, data_directory,
mapping_data_file, mapping_data_file,
data_list, data_list,
prediction_output_path, prediction_output_path,
batch_size, batch_size,
device=device, device=device,
exit_on_error=exit_on_error)
def delete_files(folder): def delete_files(folder):
""" Clear Folder Contents """ Clear Folder Contents
......
...@@ -71,7 +71,7 @@ class Solver(): ...@@ -71,7 +71,7 @@ class Solver():
use_last_checkpoint=True, use_last_checkpoint=True,
experiment_directory='experiments', experiment_directory='experiments',
logs_directory='logs', logs_directory='logs',
checkpoint_directory = 'checkpoints' checkpoint_directory='checkpoints'
): ):
self.model = model self.model = model
...@@ -203,14 +203,14 @@ class Solver(): ...@@ -203,14 +203,14 @@ class Solver():
iteration += 1 iteration += 1
losses.append(loss.item()) losses.append(loss.item())
# Clear the memory # Clear the memory
del X, y, y_hat, loss del X, y, y_hat, loss
torch.cuda.empty_cache() torch.cuda.empty_cache()
if phase == 'validation': if phase == 'validation':
if batch_index != len(dataloaders[phase]) - 1: if batch_index != len(dataloaders[phase]) - 1:
print("#", end='', flush=True) print("#", end='', flush=True)
...@@ -222,7 +222,8 @@ class Solver(): ...@@ -222,7 +222,8 @@ class Solver():
self.LogWriter.loss_per_epoch(losses, phase, epoch) self.LogWriter.loss_per_epoch(losses, phase, epoch)
if phase == 'validation': if phase == 'validation':
early_stop, save_checkpoint = self.EarlyStopping(np.mean(losses)) early_stop, save_checkpoint = self.EarlyStopping(
np.mean(losses))
self.early_stop = early_stop self.early_stop = early_stop
if save_checkpoint == True: if save_checkpoint == True:
validation_loss = np.mean(losses) validation_loss = np.mean(losses)
...@@ -233,11 +234,12 @@ class Solver(): ...@@ -233,11 +234,12 @@ class Solver():
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict() 'scheduler': learning_rate_scheduler.state_dict()
}, },
filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory, filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension) 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
) )
if epoch != self.start_epoch: if epoch != self.start_epoch:
os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension)) os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
if phase == 'train': if phase == 'train':
learning_rate_scheduler.step() learning_rate_scheduler.step()
......
...@@ -202,14 +202,14 @@ def evaluate_dice_score(trained_model_path, ...@@ -202,14 +202,14 @@ def evaluate_dice_score(trained_model_path,
def evaluate_mapping(trained_model_path, def evaluate_mapping(trained_model_path,
data_directory, data_directory,
mapping_data_file, mapping_data_file,
data_list, data_list,
prediction_output_path, prediction_output_path,
batch_size, batch_size,
device=0, device=0,
mode='evaluate', mode='evaluate',
exit_on_error=False): exit_on_error=False):
"""Model Evaluator """Model Evaluator
This function generates the rsfMRI map for an input running on on a single axis or path This function generates the rsfMRI map for an input running on on a single axis or path
...@@ -262,21 +262,24 @@ def evaluate_mapping(trained_model_path, ...@@ -262,21 +262,24 @@ def evaluate_mapping(trained_model_path,
# Initiate the evaluation # Initiate the evaluation
log.info("rsfMRI Generation Started") log.info("rsfMRI Generation Started")
file_paths = data_utils.load_file_paths(data_directory, data_list, mapping_data_file) file_paths = data_utils.load_file_paths(
data_directory, data_list, mapping_data_file)
with torch.no_grad(): with torch.no_grad():
for volume_index, file_path in enumerate(file_paths): for volume_index, file_path in enumerate(file_paths):
try: try:
# Generate volume & header # Generate volume & header
_, predicted_volume, header, xform = _generate_volume_map(file_path, model, batch_size, device, cuda_available) _, predicted_volume, header, xform = _generate_volume_map(
file_path, model, batch_size, device, cuda_available)
# Generate New Header Affine # Generate New Header Affine
header_affines = np.array( header_affines = np.array(
[header['srow_x'], header['srow_y'], header['srow_z'], [0, 0, 0, 1]]) [header['srow_x'], header['srow_y'], header['srow_z'], [0, 0, 0, 1]])
output_nifti_image = Image(predicted_volume, header=header, xform=xform) output_nifti_image = Image(
predicted_volume, header=header, xform=xform)
output_nifti_path = os.path.join( output_nifti_path = os.path.join(
prediction_output_path, volumes_to_be_used[volume_index]) prediction_output_path, volumes_to_be_used[volume_index])
...@@ -321,7 +324,8 @@ def _generate_volume_map(file_path, model, batch_size, device, cuda_available): ...@@ -321,7 +324,8 @@ def _generate_volume_map(file_path, model, batch_size, device, cuda_available):
header (class): 'nibabel.nifti1.Nifti1Header' class object, containing volume metadata header (class): 'nibabel.nifti1.Nifti1Header' class object, containing volume metadata
""" """
volume, header, xform = data_utils.load_and_preprocess_evaluation(file_path) volume, header, xform = data_utils.load_and_preprocess_evaluation(
file_path)
if len(volume.shape) == 4: if len(volume.shape) == 4:
volume = volume volume = volume
...@@ -797,4 +801,4 @@ def _generate_volume(file_path, model, orientation, batch_size, device, cuda_ava ...@@ -797,4 +801,4 @@ def _generate_volume(file_path, model, orientation, batch_size, device, cuda_ava
predicted_volume = predicted_volume predicted_volume = predicted_volume
output_volume = output_volume output_volume = output_volume
return output_volume, predicted_volume, header return output_volume, predicted_volume, header
\ No newline at end of file
...@@ -109,7 +109,8 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_ ...@@ -109,7 +109,8 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_
""" """
if data_file is not None: if data_file is not None:
subDirectoryList = data_file_reader(data_file, data_directory, subject_number) subDirectoryList = data_file_reader(
data_file, data_directory, subject_number)
else: else:
subDirectoryList = directory_reader(data_directory, subject_number) subDirectoryList = directory_reader(data_directory, subject_number)
...@@ -353,7 +354,7 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct ...@@ -353,7 +354,7 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct
file_paths = [[os.path.join(data_directory, volume, mapping_data_file)] file_paths = [[os.path.join(data_directory, volume, mapping_data_file)]
for volume in volumes_to_be_used] for volume in volumes_to_be_used]
else: else:
file_paths = [[os.path.join(data_directory, volume, mapping_data_file), os.join.path( file_paths = [[os.path.join(data_directory, volume, mapping_data_file), os.path.join(
targets_directory, volume, )] for volume in volumes_to_be_used] targets_directory, volume, )] for volume in volumes_to_be_used]
return file_paths return file_paths
...@@ -469,11 +470,9 @@ def load_and_preprocess_evaluation(file_path, min_max=False): ...@@ -469,11 +470,9 @@ def load_and_preprocess_evaluation(file_path, min_max=False):
return volume, header, xform return volume, header, xform
# Deprecated Functions & Classes & Methods: # Deprecated Functions & Classes & Methods:
def tract_sum_generator(folder_path): def tract_sum_generator(folder_path):
"""Sums the tracts of different dMRI files """Sums the tracts of different dMRI files
...@@ -646,4 +645,4 @@ def load_and_preprocess_evaluation2D(file_path, orientation, min_max=True): ...@@ -646,4 +645,4 @@ def load_and_preprocess_evaluation2D(file_path, orientation, min_max=True):
raise ValueError( raise ValueError(
"Orientation value is invalid. It must be either >>coronal<<, >>axial<< or >>sagital<< ") "Orientation value is invalid. It must be either >>coronal<<, >>axial<< or >>sagital<< ")
return volume, header return volume, header
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment