Commit 09367ded authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

bug fixes for memory bleeds during cross validation

parent d15c6593
......@@ -207,6 +207,24 @@ class BrainMapperUNet3D(nn.Module):
return prediction
def reset_parameters(self):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
if isinstance(submodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False:
submodule.reset_parameters()
print("Initialized network parameters!")
# DEPRECATED ARCHITECTURES!
......@@ -378,6 +396,23 @@ class BrainMapperUNet(nn.Module):
return prediction
def reset_parameters(self):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
if isinstance(submodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False:
submodule.reset_parameters()
print("Initialized network parameters!")
class BrainMapperUNet3D_Simple(nn.Module):
"""Architecture class BrainMapper 3D U-net.
......@@ -550,6 +585,24 @@ class BrainMapperUNet3D_Simple(nn.Module):
return prediction
def reset_parameters(self):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
if isinstance(submodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False:
submodule.reset_parameters()
print("Initialized network parameters!")
# if __name__ == '__main__':
# # For debugging - To be deleted later! TODO
......
......@@ -151,6 +151,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
else:
BrainMapperModel = BrainMapperUNet3D(network_parameters)
BrainMapperModel.reset_parameters()
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
number_of_classes=network_parameters['number_of_classes'],
......@@ -195,11 +197,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
else:
print("Training initiated using K-fold Cross Validation!")
for k in range(data_parameters['k_fold']):
k_fold_losses = []
print("K-fold Number: {}".format(k+1))
for k in range(data_parameters['k_fold']):
k_fold_losses = []
print("K-fold Number: {}".format(k+1))
data_parameters['train_list'] = os.path.join(
data_parameters['data_folder_name'], 'train' + str(k+1)+'.txt')
......
......@@ -35,7 +35,10 @@ class Settings(Mapping):
def __init__(self, settings_file='settings.ini'):
configurator = configparser.ConfigParser()
configurator.read(settings_file)
if not configurator.read(settings_file):
configurator.read('functionmapper/'+settings_file)
else:
configurator.read(settings_file)
self.settings_dictionary = _parse_values(configurator)
def __getitem__(self, key):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment