Commit 52785b10 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

bug fixes

parent 39935616
......@@ -159,7 +159,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
device=misc_parameters['device'],
number_of_classes=network_parameters['number_of_classes'],
experiment_name=training_parameters['experiment_name'],
optimizer= optimizer,
optimizer=optimizer,
optimizer_arguments={'lr': training_parameters['learning_rate'],
'betas': training_parameters['optimizer_beta'],
'eps': training_parameters['optimizer_epsilon'],
......@@ -320,6 +320,8 @@ def evaluate_mapping(mapping_evaluation_parameters):
prediction_output_path,
batch_size,
device=device,
exit_on_error=exit_on_error)
def delete_files(folder):
""" Clear Folder Contents
......
......@@ -71,7 +71,7 @@ class Solver():
use_last_checkpoint=True,
experiment_directory='experiments',
logs_directory='logs',
checkpoint_directory = 'checkpoints'
checkpoint_directory='checkpoints'
):
self.model = model
......@@ -222,7 +222,8 @@ class Solver():
self.LogWriter.loss_per_epoch(losses, phase, epoch)
if phase == 'validation':
early_stop, save_checkpoint = self.EarlyStopping(np.mean(losses))
early_stop, save_checkpoint = self.EarlyStopping(
np.mean(losses))
self.early_stop = early_stop
if save_checkpoint == True:
validation_loss = np.mean(losses)
......@@ -237,7 +238,8 @@ class Solver():
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
if epoch != self.start_epoch:
os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
if phase == 'train':
learning_rate_scheduler.step()
......
......@@ -262,21 +262,24 @@ def evaluate_mapping(trained_model_path,
# Initiate the evaluation
log.info("rsfMRI Generation Started")
file_paths = data_utils.load_file_paths(data_directory, data_list, mapping_data_file)
file_paths = data_utils.load_file_paths(
data_directory, data_list, mapping_data_file)
with torch.no_grad():
for volume_index, file_path in enumerate(file_paths):
try:
# Generate volume & header
_, predicted_volume, header, xform = _generate_volume_map(file_path, model, batch_size, device, cuda_available)
_, predicted_volume, header, xform = _generate_volume_map(
file_path, model, batch_size, device, cuda_available)
# Generate New Header Affine
header_affines = np.array(
[header['srow_x'], header['srow_y'], header['srow_z'], [0, 0, 0, 1]])
output_nifti_image = Image(predicted_volume, header=header, xform=xform)
output_nifti_image = Image(
predicted_volume, header=header, xform=xform)
output_nifti_path = os.path.join(
prediction_output_path, volumes_to_be_used[volume_index])
......@@ -321,7 +324,8 @@ def _generate_volume_map(file_path, model, batch_size, device, cuda_available):
header (class): 'nibabel.nifti1.Nifti1Header' class object, containing volume metadata
"""
volume, header, xform = data_utils.load_and_preprocess_evaluation(file_path)
volume, header, xform = data_utils.load_and_preprocess_evaluation(
file_path)
if len(volume.shape) == 4:
volume = volume
......
......@@ -109,7 +109,8 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_
"""
if data_file is not None:
subDirectoryList = data_file_reader(data_file, data_directory, subject_number)
subDirectoryList = data_file_reader(
data_file, data_directory, subject_number)
else:
subDirectoryList = directory_reader(data_directory, subject_number)
......@@ -353,7 +354,7 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct
file_paths = [[os.path.join(data_directory, volume, mapping_data_file)]
for volume in volumes_to_be_used]
else:
file_paths = [[os.path.join(data_directory, volume, mapping_data_file), os.join.path(
file_paths = [[os.path.join(data_directory, volume, mapping_data_file), os.path.join(
targets_directory, volume, )] for volume in volumes_to_be_used]
return file_paths
......@@ -469,11 +470,9 @@ def load_and_preprocess_evaluation(file_path, min_max=False):
return volume, header, xform
# Deprecated Functions & Classes & Methods:
def tract_sum_generator(folder_path):
"""Sums the tracts of different dMRI files
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment