Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Vaanathi Sundaresan
truenet_tumseg
Commits
b1310e7b
Commit
b1310e7b
authored
May 24, 2021
by
Vaanathi Sundaresan
Browse files
train function call
parent
97165c6f
Changes
1
Hide whitespace changes
Inline
Side-by-side
truenet_tumseg/truenet_tumorseg/truenet_tumseg_train_function.py
0 → 100644
View file @
b1310e7b
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
torch
import
torch.nn
as
nn
from
torch
import
optim
from
truenet_tumseg.truenet_tumorseg
import
(
truenet_tumseg_loss_functions
,
truenet_tumseg_model
,
truenet_tumseg_train
)
from
truenet_tumseg.utils
import
truenet_tumseg_utils
################################################################################################
# Truenet_tumseg main training function
# Vaanathi Sundaresan
# 01-05-2021, Oxford
################################################################################################
def
main
(
sub_name_dicts
,
tr_params
,
aug
=
True
,
save_cp
=
True
,
save_wei
=
True
,
save_case
=
'last'
,
verbose
=
True
,
dir_cp
=
None
):
"""
The main training function
:param sub_name_dicts: list of dictionaries containing training filpaths
:param tr_params: dictionary of training parameters
:param aug: bool, perform data augmentation
:param save_cp: bool, save checkpoints
:param save_wei: bool, if False, the whole model will be saved
:param save_case: str, condition for saving the checkpoint
:param verbose: bool, display debug messages
:param dir_cp: str, directory for saving model/weights
:return: trained model
"""
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
assert
len
(
sub_name_dicts
)
>=
5
,
"Number of distinct subjects for training cannot be less than 5"
optim_type
=
tr_params
[
'Optimizer'
]
# adam, sgd
milestones
=
tr_params
[
'LR_Milestones'
]
# list of integers [1, N]
gamma
=
tr_params
[
'LR_red_factor'
]
# scalar (0,1)
lrt
=
tr_params
[
'Learning_rate'
]
# scalar (0,1)
req_plane
=
tr_params
[
'Acq_plane'
]
# string ('axial', 'sagittal', 'coronal', 'all')
train_prop
=
tr_params
[
'Train_prop'
]
# scale (0,1)
nclass
=
tr_params
[
'Nclass'
]
modalities
=
tr_params
[
'Num_modalities'
]
if
any
(
modalities
)
>
1
:
raise
ValueError
(
'Only 1 and 0 allowed. At least one of the modalities must be selected to be 1'
)
nchannels
=
sum
(
modalities
)
if
nclass
==
2
:
criterion
=
truenet_tumseg_loss_functions
.
CombinedLoss
()
else
:
criterion
=
truenet_tumseg_loss_functions
.
CombinedMultiLoss
(
nclasses
=
nclass
)
if
verbose
:
print
(
'Found'
+
str
(
len
(
sub_name_dicts
))
+
'subjects'
,
flush
=
True
)
num_val_subs
=
max
(
int
(
len
(
sub_name_dicts
)
*
(
1
-
train_prop
)),
1
)
train_name_dicts
,
val_name_dicts
,
val_ids
=
truenet_tumseg_utils
.
select_train_val_names
(
sub_name_dicts
,
num_val_subs
)
if
type
(
milestones
)
!=
list
:
milestones
=
[
milestones
]
models
=
[]
if
req_plane
==
'all'
or
req_plane
==
'axial'
:
model_axial
=
truenet_tumseg_model
.
TrUENetTumSeg
(
n_channels
=
nchannels
,
n_classes
=
nclass
,
init_channels
=
64
,
plane
=
'axial'
)
model_axial
.
to
(
device
=
device
)
model_axial
=
nn
.
DataParallel
(
model_axial
)
print
(
'Total number of Axial model to train: '
,
str
(
sum
([
p
.
numel
()
for
p
in
model_axial
.
parameters
()])),
flush
=
True
)
if
optim_type
==
'adam'
:
epsilon
=
tr_params
[
'Epsilon'
]
optimizer_axial
=
optim
.
Adam
(
filter
(
lambda
p
:
p
.
requires_grad
,
model_axial
.
parameters
()),
lr
=
lrt
,
eps
=
epsilon
)
elif
optim_type
==
'sgd'
:
moment
=
tr_params
[
'Momentum'
]
optimizer_axial
=
optim
.
SGD
(
filter
(
lambda
p
:
p
.
requires_grad
,
model_axial
.
parameters
()),
lr
=
lrt
,
momentum
=
moment
)
else
:
raise
ValueError
(
"Invalid optimiser choice provided! Valid options: 'adam', 'sgd'"
)
scheduler
=
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer_axial
,
milestones
,
gamma
=
gamma
,
last_epoch
=-
1
)
model_axial
=
truenet_tumseg_train
.
train_truenet
(
train_name_dicts
,
val_name_dicts
,
model_axial
,
criterion
,
optimizer_axial
,
scheduler
,
tr_params
,
device
,
mode
=
'axial'
,
augment
=
aug
,
save_checkpoint
=
save_cp
,
save_weights
=
save_wei
,
save_case
=
save_case
,
verbose
=
verbose
,
dir_checkpoint
=
dir_cp
)
models
.
append
(
model_axial
)
if
req_plane
==
'all'
or
req_plane
==
'sagittal'
:
model_sagittal
=
truenet_tumseg_model
.
TrUENetTumSeg
(
n_channels
=
nchannels
,
n_classes
=
nclass
,
init_channels
=
64
,
plane
=
'sagittal'
)
model_sagittal
.
to
(
device
=
device
)
model_sagittal
=
nn
.
DataParallel
(
model_sagittal
)
print
(
'Total number of Sagittal model to train: '
,
str
(
sum
([
p
.
numel
()
for
p
in
model_sagittal
.
parameters
()])),
flush
=
True
)
if
optim_type
==
'adam'
:
epsilon
=
tr_params
[
'Epsilon'
]
optimizer_sagittal
=
optim
.
Adam
(
filter
(
lambda
p
:
p
.
requires_grad
,
model_sagittal
.
parameters
()),
lr
=
lrt
,
eps
=
epsilon
)
elif
optim_type
==
'sgd'
:
moment
=
tr_params
[
'Momentum'
]
optimizer_sagittal
=
optim
.
SGD
(
filter
(
lambda
p
:
p
.
requires_grad
,
model_sagittal
.
parameters
()),
lr
=
lrt
,
momentum
=
moment
)
else
:
raise
ValueError
(
"Invalid optimiser choice provided! Valid options: 'adam', 'sgd'"
)
scheduler
=
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer_sagittal
,
milestones
,
gamma
=
gamma
,
last_epoch
=-
1
)
model_sagittal
=
truenet_tumseg_train
.
train_truenet
(
train_name_dicts
,
val_name_dicts
,
model_sagittal
,
criterion
,
optimizer_sagittal
,
scheduler
,
tr_params
,
device
,
mode
=
'sagittal'
,
augment
=
aug
,
save_checkpoint
=
save_cp
,
save_weights
=
save_wei
,
save_case
=
save_case
,
verbose
=
verbose
,
dir_checkpoint
=
dir_cp
)
models
.
append
(
model_sagittal
)
if
req_plane
==
'all'
or
req_plane
==
'coronal'
:
model_coronal
=
truenet_tumseg_model
.
TrUENetTumSeg
(
n_channels
=
nchannels
,
n_classes
=
nclass
,
init_channels
=
64
,
plane
=
'coronal'
)
model_coronal
.
to
(
device
=
device
)
model_coronal
=
nn
.
DataParallel
(
model_coronal
)
print
(
'Total number of Coronal model to train: '
,
str
(
sum
([
p
.
numel
()
for
p
in
model_coronal
.
parameters
()])),
flush
=
True
)
if
optim_type
==
'adam'
:
epsilon
=
tr_params
[
'Epsilon'
]
optimizer_coronal
=
optim
.
Adam
(
filter
(
lambda
p
:
p
.
requires_grad
,
model_coronal
.
parameters
()),
lr
=
lrt
,
eps
=
epsilon
)
elif
optim_type
==
'sgd'
:
moment
=
tr_params
[
'Momentum'
]
optimizer_coronal
=
optim
.
SGD
(
filter
(
lambda
p
:
p
.
requires_grad
,
model_coronal
.
parameters
()),
lr
=
lrt
,
momentum
=
moment
)
else
:
raise
ValueError
(
"Invalid optimiser choice provided! Valid options: 'adam', 'sgd'"
)
scheduler
=
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer_coronal
,
milestones
,
gamma
=
gamma
,
last_epoch
=-
1
)
model_coronal
=
truenet_tumseg_train
.
train_truenet
(
train_name_dicts
,
val_name_dicts
,
model_coronal
,
criterion
,
optimizer_coronal
,
scheduler
,
tr_params
,
device
,
mode
=
'coronal'
,
augment
=
aug
,
save_checkpoint
=
save_cp
,
save_weights
=
save_wei
,
save_case
=
save_case
,
verbose
=
verbose
,
dir_checkpoint
=
dir_cp
)
models
.
append
(
model_coronal
)
if
req_plane
==
'all'
or
req_plane
==
'tc'
:
criterion
=
truenet_tumseg_loss_functions
.
CombinedLoss
()
model_tc
=
truenet_tumseg_model
.
TrUENetTumSeg
(
n_channels
=
nchannels
,
n_classes
=
2
,
init_channels
=
64
,
plane
=
'axial'
)
model_tc
.
to
(
device
=
device
)
model_tc
=
nn
.
DataParallel
(
model_tc
)
print
(
'Total number of Tumour core model to train: '
,
str
(
sum
([
p
.
numel
()
for
p
in
model_tc
.
parameters
()])),
flush
=
True
)
if
optim_type
==
'adam'
:
epsilon
=
tr_params
[
'Epsilon'
]
optimizer_tc
=
optim
.
Adam
(
filter
(
lambda
p
:
p
.
requires_grad
,
model_tc
.
parameters
()),
lr
=
lrt
,
eps
=
epsilon
)
elif
optim_type
==
'sgd'
:
moment
=
tr_params
[
'Momentum'
]
optimizer_tc
=
optim
.
SGD
(
filter
(
lambda
p
:
p
.
requires_grad
,
model_tc
.
parameters
()),
lr
=
lrt
,
momentum
=
moment
)
else
:
raise
ValueError
(
"Invalid optimiser choice provided! Valid options: 'adam', 'sgd'"
)
scheduler
=
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer_tc
,
milestones
,
gamma
=
gamma
,
last_epoch
=-
1
)
model_tc
=
truenet_tumseg_train
.
train_truenet
(
train_name_dicts
,
val_name_dicts
,
model_tc
,
criterion
,
optimizer_tc
,
scheduler
,
tr_params
,
device
,
mode
=
'tc'
,
augment
=
aug
,
save_checkpoint
=
save_cp
,
save_weights
=
save_wei
,
save_case
=
save_case
,
verbose
=
verbose
,
dir_checkpoint
=
dir_cp
)
models
.
append
(
model_tc
)
return
models
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment