Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Vaanathi Sundaresan
truenet_tumseg
Commits
450fa91b
Commit
450fa91b
authored
May 24, 2021
by
Vaanathi Sundaresan
Browse files
general util functions
parent
2e6a2633
Changes
1
Hide whitespace changes
Inline
Side-by-side
truenet_tumseg/utils/truenet_tumseg_utils.py
0 → 100644
View file @
450fa91b
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
random
import
os
import
torch
from
skimage.measure
import
label
,
regionprops
from
scipy
import
ndimage
import
math
from
collections
import
OrderedDict
#=========================================================================================
# Truenet general utility functions
# Vaanathi Sundaresan
# 09-03-2021, Oxford
#=========================================================================================
def
select_train_val_names
(
data_path
,
val_numbers
):
'''
Select training and validation subjects randomly given th no. of validation subjects
:param data_path: input filepaths
:param val_numbers: int, number of validation subjects
:return:
'''
val_ids
=
random
.
choices
(
list
(
np
.
arange
(
len
(
data_path
))),
k
=
val_numbers
)
train_ids
=
np
.
setdiff1d
(
np
.
arange
(
len
(
data_path
)),
val_ids
)
data_path_train
=
[
data_path
[
ind
]
for
ind
in
train_ids
]
data_path_val
=
[
data_path
[
ind
]
for
ind
in
val_ids
]
return
data_path_train
,
data_path_val
,
val_ids
def
freeze_layer_for_finetuning
(
model
,
layer_to_ft
,
verbose
=
False
):
'''
Unfreezing specific layers of the model for fine-tuning
:param model: model
:param layer_to_ft: list of ints, layers to fine-tune starting from the decoder end.
:param verbose: bool, display debug messages
:return: model after unfreezing only the required layers
'''
model_layer_names
=
[
'outconv'
,
'up1'
,
'up2'
,
'up3'
,
'down3'
,
'down2'
,
'down1'
,
'convfirst'
]
model_layers_tobe_ftd
=
[]
for
layer_id
in
layer_to_ft
:
model_layers_tobe_ftd
.
append
(
model_layer_names
[
layer_id
-
1
])
for
name
,
child
in
model
.
module
.
named_children
():
if
name
in
model_layers_tobe_ftd
:
if
verbose
:
print
(
'Model parameters'
,
flush
=
True
)
print
(
name
+
' is unfrozen'
,
flush
=
True
)
for
param
in
child
.
parameters
():
param
.
requires_grad
=
True
else
:
if
verbose
:
print
(
'Model parameters'
,
flush
=
True
)
print
(
name
+
' is frozen'
,
flush
=
True
)
for
param
in
child
.
parameters
():
param
.
requires_grad
=
False
return
model
def
loading_model
(
model_name
,
model
,
mode
=
'weights'
):
if
mode
==
'weights'
:
try
:
axial_state_dict
=
torch
.
load
(
model_name
)
except
:
axial_state_dict
=
torch
.
load
(
model_name
,
map_location
=
'cpu'
)
else
:
try
:
ckpt
=
torch
.
load
(
model_name
)
except
:
ckpt
=
torch
.
load
(
model_name
,
map_location
=
'cpu'
)
axial_state_dict
=
ckpt
[
'model_state_dict'
]
new_axial_state_dict
=
OrderedDict
()
for
key
,
value
in
axial_state_dict
.
items
():
if
'module.'
in
key
[:
7
]:
name
=
key
# remove `module.`
else
:
name
=
'module.'
+
key
new_axial_state_dict
[
name
]
=
value
model
.
load_state_dict
(
new_axial_state_dict
)
return
model
def
post_processing_including_tc
(
seglab
,
tc
):
final_seg
=
np
.
zeros
(
seglab
.
shape
)
label_tc
=
(
tc
==
1
).
astype
(
int
)
label3
=
(
seglab
==
3
).
astype
(
int
)
label2
=
(
seglab
==
2
).
astype
(
int
)
label1
=
(
seglab
==
1
).
astype
(
int
)
if
np
.
sum
(
label3
)
<=
200
:
seglab
[
seglab
==
3
]
=
1
label3
=
(
seglab
==
3
).
astype
(
int
)
label1
=
(
seglab
==
1
).
astype
(
int
)
elif
np
.
sum
(
label3
)
>
2500
:
label2_tmp
=
((
label1
+
label2
)
>
0
).
astype
(
int
)
labelled3
,
nlab3
=
label
(
label3
>
0
,
return_num
=
True
)
label3_fill
=
np
.
zeros
(
labelled3
.
shape
)
for
i
in
range
(
nlab3
):
label3_tmp
=
ndimage
.
morphology
.
binary_closing
(
labelled3
==
(
i
+
1
),
structure
=
np
.
ones
((
3
,
3
,
3
)),
iterations
=
10
)
label3_fill
=
label3_fill
+
label3_tmp
.
astype
(
int
)
label3_fill
=
(
label3_fill
>
0
).
astype
(
int
)
label3_dist
=
ndimage
.
morphology
.
distance_transform_edt
(
1
-
label3_fill
)
label1
=
(
label3_dist
<
1
).
astype
(
int
)
*
label1
label2
=
label2_tmp
-
label1
labeled
,
nlab
=
label
(
label2
==
1
,
return_num
=
True
)
props
=
regionprops
(
labeled
)
areas
=
[
prop
.
area
for
prop
in
props
]
if
len
(
areas
)
>
1
:
sorted_area
=
np
.
argsort
(
np
.
array
(
areas
))
cents
=
[
prop
.
centroid
for
prop
in
props
]
x0
=
cents
[
sorted_area
[
-
1
]][
0
]
y0
=
cents
[
sorted_area
[
-
1
]][
1
]
z0
=
cents
[
sorted_area
[
-
1
]][
2
]
dists
=
[
math
.
sqrt
((
cent
[
0
]
-
x0
)
**
2
+
(
cent
[
1
]
-
y0
)
**
2
+
(
cent
[
2
]
-
z0
)
**
2
)
for
cent
in
cents
]
else
:
dists
=
0
float_ind
=
np
.
where
(
np
.
array
(
dists
)
>=
75
)[
0
]
if
float_ind
.
size
:
for
sh
in
range
(
float_ind
.
shape
[
0
]):
labeled
[
labeled
==
(
float_ind
[
sh
]
+
1
)]
=
0
label2_final
=
(
labeled
>
0
).
astype
(
int
)
label1
=
((
label1
+
label_tc
)
>
0
).
astype
(
int
)
final_seg
[
label2_final
==
1
]
=
2
final_seg
[
label1
==
1
]
=
1
final_seg
[
label3
==
1
]
=
3
final_seg
=
final_seg
.
astype
(
int
)
return
final_seg
class
EarlyStoppingModelCheckpointing
:
'''
Early stopping stops the training if the validation loss doesnt improve after a given patience
'''
def
__init__
(
self
,
patience
=
5
,
verbose
=
False
):
self
.
patience
=
patience
self
.
verbose
=
verbose
self
.
counter
=
0
self
.
best_score
=
None
self
.
early_stop
=
False
self
.
val_loss_min
=
np
.
Inf
def
__call__
(
self
,
val_loss
,
val_dice
,
best_val_dice
,
model
,
epoch
,
optimizer
,
scheduler
,
loss
,
tr_prms
,
weights
=
True
,
checkpoint
=
True
,
save_condition
=
'best'
,
model_path
=
None
,
plane
=
'axial'
):
score
=
-
val_loss
if
self
.
best_score
is
None
:
self
.
best_score
=
score
self
.
save_checkpoint
(
val_loss
,
val_dice
,
best_val_dice
,
model
,
epoch
,
optimizer
,
scheduler
,
loss
,
tr_prms
,
weights
,
checkpoint
,
save_condition
,
model_path
,
plane
)
elif
score
<
self
.
best_score
:
# Here is the criteria for activation of early stopping counter.
self
.
counter
+=
1
print
(
'Early Stopping Counter: '
,
self
.
counter
,
'/'
,
self
.
patience
)
if
self
.
counter
>=
self
.
patience
:
# When the counter reaches the patience value, early stopping flag is activated to stop the training.
self
.
early_stop
=
True
else
:
self
.
best_score
=
score
self
.
save_checkpoint
(
val_loss
,
val_dice
,
best_val_dice
,
model
,
epoch
,
optimizer
,
scheduler
,
loss
,
tr_prms
,
weights
,
checkpoint
,
save_condition
,
model_path
,
plane
)
self
.
counter
=
0
def
save_checkpoint
(
self
,
val_loss
,
val_acc
,
best_val_acc
,
model
,
epoch
,
optimizer
,
scheduler
,
loss
,
tr_prms
,
weights
,
checkpoint
,
save_condition
,
PATH
,
plane
):
# Saving checkpoints
if
checkpoint
:
# Saves the model when the validation loss decreases
if
self
.
verbose
:
print
(
'Validation loss increased; Saving model ...'
)
if
weights
:
if
save_condition
==
'best'
:
save_path
=
os
.
path
.
join
(
PATH
,
'Truenet_model_weights_bestdice_'
+
plane
+
'.pth'
)
if
val_acc
>
best_val_acc
:
torch
.
save
(
model
.
state_dict
(),
save_path
)
elif
save_condition
==
'everyN'
:
N
=
tr_prms
[
'EveryN'
]
if
(
epoch
%
N
)
==
0
:
save_path
=
os
.
path
.
join
(
PATH
,
'Truenet_model_weights_epoch'
+
str
(
epoch
)
+
'_'
+
plane
+
'.pth'
)
torch
.
save
(
model
.
state_dict
(),
save_path
)
elif
save_condition
==
'last'
:
save_path
=
os
.
path
.
join
(
PATH
,
'Truenet_model_weights_beforeES_'
+
plane
+
'.pth'
)
torch
.
save
(
model
.
state_dict
(),
save_path
)
else
:
raise
ValueError
(
"Invalid saving condition provided! Valid options: best, everyN, last"
)
else
:
if
save_condition
==
'best'
:
save_path
=
os
.
path
.
join
(
PATH
,
'Truenet_model_bestdice_'
+
plane
+
'.pth'
)
if
val_acc
>
best_val_acc
:
torch
.
save
({
'epoch'
:
epoch
,
'model_state_dict'
:
model
.
state_dict
(),
'optimizer_state_dict'
:
optimizer
.
state_dict
(),
'scheduler_stat_dict'
:
scheduler
.
state_dict
(),
'loss'
:
loss
},
save_path
)
elif
save_condition
==
'everyN'
:
N
=
tr_prms
[
'EveryN'
]
if
(
epoch
%
N
)
==
0
:
save_path
=
os
.
path
.
join
(
PATH
,
'Truenet_model_epoch'
+
str
(
epoch
)
+
'_'
+
plane
+
'.pth'
)
torch
.
save
({
'epoch'
:
epoch
,
'model_state_dict'
:
model
.
state_dict
(),
'optimizer_state_dict'
:
optimizer
.
state_dict
(),
'scheduler_stat_dict'
:
scheduler
.
state_dict
(),
'loss'
:
loss
},
save_path
)
elif
save_condition
==
'last'
:
save_path
=
os
.
path
.
join
(
PATH
,
'Truenet_model_beforeES_'
+
plane
+
'.pth'
)
torch
.
save
({
'epoch'
:
epoch
,
'model_state_dict'
:
model
.
state_dict
(),
'optimizer_state_dict'
:
optimizer
.
state_dict
(),
'scheduler_stat_dict'
:
scheduler
.
state_dict
(),
'loss'
:
loss
},
save_path
)
else
:
raise
ValueError
(
"Invalid saving condition provided! Valid options: best, everyN, last"
)
else
:
if
self
.
verbose
:
print
(
'Validation loss increased; Exiting without saving the model ...'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment