Commit 2ab1cdfe authored by Andrei Roibu's avatar Andrei Roibu
Browse files

updated imports, removed k-fold + data shuffling

parent 159001d8
......@@ -33,7 +33,6 @@ import os
import shutil
import argparse
import logging
from settings import Settings
import torch
import torch.utils.data as data
......@@ -41,9 +40,12 @@ import numpy as np
from solver import Solver
from BrainMapperAE import BrainMapperAE3D
from utils.data_utils import get_datasets, data_preparation, update_shuffling_flag, create_folder
from utils.data_utils import get_datasets
from utils.settings import Settings
import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter
from utils.common_utils import create_folder
from utils.preprocessor import data_preparation, update_shuffling_flag
# Set the default floating point tensor type to FloatTensor
......@@ -187,34 +189,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
return validation_loss
if data_parameters['k_fold'] is None:
_ = _train_runner(data_parameters, training_parameters,
network_parameters, misc_parameters)
else:
print("Training initiated using K-fold Cross Validation!")
k_fold_losses = []
for k in range(data_parameters['k_fold']):
print("K-fold Number: {}".format(k+1))
data_parameters['train_list'] = os.path.join(
data_parameters['data_folder_name'], 'train' + str(k+1)+'.txt')
data_parameters['validation_list'] = os.path.join(
data_parameters['data_folder_name'], 'validation' + str(k+1)+'.txt')
training_parameters['final_model_output_file'] = training_parameters['final_model_output_file'].replace(
".pth.tar", str(k+1)+".pth.tar")
validation_loss = _train_runner(
data_parameters, training_parameters, network_parameters, misc_parameters)
k_fold_losses.append(validation_loss)
for k in range(data_parameters['k_fold']):
print("K-fold Number: {} Loss: {}".format(k+1, k_fold_losses[k]))
print("K-fold Cross Validation Avearge Loss: {}".format(np.mean(k_fold_losses)))
_ = _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters)
def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
......@@ -363,38 +339,6 @@ if __name__ == '__main__':
misc_parameters = settings['MISC']
evaluation_parameters = settings['EVALUATION']
# Here we shuffle the data!
if data_parameters['data_split_flag'] == True:
print('Data is shuffling... This could take a few minutes!')
if data_parameters['use_data_file'] == True:
data_preparation(data_parameters['data_folder_name'],
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
data_file=data_parameters['data_file'],
K_fold=data_parameters['k_fold']
)
else:
data_preparation(data_parameters['data_folder_name'],
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
K_fold=data_parameters['k_fold']
)
update_shuffling_flag('settings.ini')
print('Data is shuffling... Complete!')
if arguments.mode == 'train':
train(data_parameters, training_parameters,
network_parameters, misc_parameters)
......@@ -432,9 +376,6 @@ if __name__ == '__main__':
network_parameters, misc_parameters)
logging.basicConfig(filename='evaluate-mapping-error.log')
evaluate_mapping(mapping_evaluation_parameters)
elif arguments.mode == 'prepare-data':
print('Ensure you have updated the settings.ini file accordingly! This call does nothing but pass after data was shuffled!')
pass
else:
raise ValueError(
'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, prepare-data, clear-experiments and clear-everything')
'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, clear-experiments and clear-everything')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment