Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Vaanathi Sundaresan
truenet_tumseg
Commits
bab5eacf
Commit
bab5eacf
authored
May 24, 2021
by
Vaanathi Sundaresan
Browse files
test function
parent
9382f81a
Changes
1
Hide whitespace changes
Inline
Side-by-side
truenet_tumseg/truenet_tumorseg/truenet_tumseg_test_function.py
0 → 100644
View file @
bab5eacf
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
os
import
nibabel
as
nib
from
truenet_tumseg.truenet_tumorseg
import
(
truenet_tumseg_model
,
truenet_tumseg_evaluate
,
truenet_tumseg_data_postprocessing
)
from
truenet_tumseg.utils
import
truenet_tumseg_utils
################################################################################################
# Truenet main test function
# Vaanathi Sundaresan
# 09-03-2021, Oxford
################################################################################################
def
main
(
sub_name_dicts
,
eval_params
,
intermediate
=
False
,
model_dir
=
None
,
load_case
=
'last'
,
output_dir
=
None
,
verbose
=
False
):
'''
The main function for testing Truenet
:param sub_name_dicts: list of dictionaries containing subject filepaths
:param eval_params: dictionary of evaluation parameters
:param intermediate: bool, whether to save intermediate results
:param model_dir: str, filepath containing the test model
:param load_case: str, condition for loading the checkpoint
:param output_dir: str, filepath for saving the output predictions
:param verbose: bool, display debug messages
'''
assert
len
(
sub_name_dicts
)
>
0
,
"There must be at least 1 subject for testing."
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
nclass
=
eval_params
[
'Nclass'
]
pretrained
=
eval_params
[
'Pretrained'
]
postproc
=
eval_params
[
'Postproc'
]
modalities
=
eval_params
[
'Num_modalities'
]
if
any
(
modalities
)
>
1
:
raise
ValueError
(
'Only 1 and 0 allowed. At least one of the modalities must be selected to be 1'
)
nchannels
=
sum
(
modalities
)
model_axial
=
truenet_tumseg_model
.
TrUENetTumSeg
(
n_channels
=
nchannels
,
n_classes
=
nclass
,
init_channels
=
64
,
plane
=
'axial'
)
model_sagittal
=
truenet_tumseg_model
.
TrUENetTumSeg
(
n_channels
=
nchannels
,
n_classes
=
nclass
,
init_channels
=
64
,
plane
=
'sagittal'
)
model_coronal
=
truenet_tumseg_model
.
TrUENetTumSeg
(
n_channels
=
nchannels
,
n_classes
=
nclass
,
init_channels
=
64
,
plane
=
'coronal'
)
model_axial
.
to
(
device
=
device
)
model_sagittal
.
to
(
device
=
device
)
model_coronal
.
to
(
device
=
device
)
model_axial
=
nn
.
DataParallel
(
model_axial
)
model_sagittal
=
nn
.
DataParallel
(
model_sagittal
)
model_coronal
=
nn
.
DataParallel
(
model_coronal
)
if
postproc
is
True
:
model_tc
=
truenet_tumseg_model
.
TrUENetTumSeg
(
n_channels
=
nchannels
,
n_classes
=
2
,
init_channels
=
64
,
plane
=
'tc'
)
model_tc
.
to
(
device
=
device
)
model_tc
=
nn
.
DataParallel
(
model_tc
)
if
pretrained
:
if
load_case
==
'last'
:
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_beforeES_axial.pth'
)
model_axial
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_axial
,
mode
=
'full_model'
)
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_beforeES_sagittal.pth'
)
model_sagittal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_sagittal
,
mode
=
'full_model'
)
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_beforeES_coronal.pth'
)
model_coronal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_coronal
,
mode
=
'full_model'
)
if
postproc
is
True
:
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_beforeES_tc.pth'
)
model_tc
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_tc
,
mode
=
'full_model'
)
elif
load_case
==
'best'
:
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_bestdice_axial.pth'
)
model_axial
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_axial
,
mode
=
'full_model'
)
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_bestdice_sagittal.pth'
)
model_sagittal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_sagittal
,
mode
=
'full_model'
)
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_bestdice_coronal.pth'
)
model_coronal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_coronal
,
mode
=
'full_model'
)
if
postproc
is
True
:
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_bestdice_tc.pth'
)
model_tc
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_tc
,
mode
=
'full_model'
)
elif
load_case
==
'everyN'
:
cpn
=
eval_params
[
'EveryN'
]
try
:
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_epoch'
+
str
(
cpn
)
+
'_axial.pth'
)
model_axial
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_axial
,
mode
=
'full_model'
)
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_epoch'
+
str
(
cpn
)
+
'_sagittal.pth'
)
model_sagittal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_sagittal
,
mode
=
'full_model'
)
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_epoch'
+
str
(
cpn
)
+
'_coronal.pth'
)
model_coronal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_coronal
,
mode
=
'full_model'
)
if
postproc
is
True
:
model_path
=
os
.
path
.
join
(
model_dir
,
'Truenet_model_epoch'
+
str
(
cpn
)
+
'_tc.pth'
)
model_tc
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_tc
,
mode
=
'full_model'
)
except
ImportError
:
raise
ImportError
(
'Incorrect N value provided for the available pretrained models'
)
else
:
raise
ValueError
(
"Invalid saving condition provided! Valid options: best, specific, last"
)
else
:
model_name
=
eval_params
[
'Modelname'
]
try
:
model_path
=
os
.
path
.
join
(
model_dir
,
model_name
+
'_axial.pth'
)
model_axial
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_axial
)
model_path
=
os
.
path
.
join
(
model_dir
,
model_name
+
'_sagittal.pth'
)
model_sagittal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_sagittal
)
model_path
=
os
.
path
.
join
(
model_dir
,
model_name
+
'_coronal.pth'
)
model_coronal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_coronal
)
if
postproc
is
True
:
model_path
=
os
.
path
.
join
(
model_dir
,
model_name
+
'_tc.pth'
)
model_tc
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_tc
)
except
:
try
:
model_path
=
os
.
path
.
join
(
model_dir
,
model_name
+
'_axial.pth'
)
model_axial
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_axial
,
mode
=
'full_model'
)
model_path
=
os
.
path
.
join
(
model_dir
,
model_name
+
'_sagittal.pth'
)
model_sagittal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_sagittal
,
mode
=
'full_model'
)
model_path
=
os
.
path
.
join
(
model_dir
,
model_name
+
'_coronal.pth'
)
model_coronal
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_coronal
,
mode
=
'full_model'
)
if
postproc
is
True
:
model_path
=
os
.
path
.
join
(
model_dir
,
model_name
+
'_tc.pth'
)
model_tc
=
truenet_tumseg_utils
.
loading_model
(
model_path
,
model_tc
,
mode
=
'full_model'
)
except
ImportError
:
if
postproc
is
True
:
raise
ImportError
(
'In directory '
+
model_dir
+
', '
+
model_name
+
'_axial.pth or '
+
model_name
+
'_sagittal.pth or '
+
model_name
+
'_coronal.pth or '
+
model_name
+
'_tc.pth does not appear to be a valid model file'
)
else
:
raise
ImportError
(
'In directory '
+
model_dir
+
', '
+
model_name
+
'_axial.pth or'
+
model_name
+
'_sagittal.pth or'
+
model_name
+
'_coronal.pth '
+
'does not appear to be a valid model file'
)
if
verbose
:
print
(
'Found'
+
str
(
len
(
sub_name_dicts
))
+
'subjects'
,
flush
=
True
)
for
sub
in
range
(
len
(
sub_name_dicts
)):
if
verbose
:
print
(
'Predicting output for subject '
+
str
(
sub
+
1
)
+
'...'
,
flush
=
True
)
test_sub_dict
=
[
sub_name_dicts
[
sub
]]
basename
=
test_sub_dict
[
0
][
'basename'
]
probs_combined
=
[]
flair_path
=
test_sub_dict
[
0
][
'flair_path'
]
flair_hdr
=
nib
.
load
(
flair_path
).
header
probs_axial
=
truenet_tumseg_evaluate
.
evaluate_truenet
(
test_sub_dict
,
model_axial
,
eval_params
,
device
,
mode
=
'axial'
,
verbose
=
verbose
)
probs_axial
=
truenet_tumseg_data_postprocessing
.
resize_to_original_size
(
probs_axial
,
test_sub_dict
,
plane
=
'axial'
)
probs_combined
.
append
(
probs_axial
)
if
intermediate
:
save_path
=
os
.
path
.
join
(
output_dir
,
'Predicted_probmap_truenet_'
+
basename
+
'_axial.nii.gz'
)
preds_axial
=
truenet_tumseg_data_postprocessing
.
get_final_3dvolumes
(
np
.
argmax
(
probs_axial
,
axis
=
3
),
test_sub_dict
)
if
verbose
:
print
(
'Saving the intermediate Axial prediction ...'
,
flush
=
True
)
newhdr
=
flair_hdr
.
copy
()
newobj
=
nib
.
nifti1
.
Nifti1Image
(
preds_axial
,
None
,
header
=
newhdr
)
nib
.
save
(
newobj
,
save_path
)
probs_sagittal
=
truenet_tumseg_evaluate
.
evaluate_truenet
(
test_sub_dict
,
model_sagittal
,
eval_params
,
device
,
mode
=
'sagittal'
,
verbose
=
verbose
)
probs_sagittal
=
truenet_tumseg_data_postprocessing
.
resize_to_original_size
(
probs_sagittal
,
test_sub_dict
,
plane
=
'sagittal'
)
probs_combined
.
append
(
probs_sagittal
)
if
intermediate
:
save_path
=
os
.
path
.
join
(
output_dir
,
'Predicted_probmap_truenet_'
+
basename
+
'_sagittal.nii.gz'
)
preds_sagittal
=
truenet_tumseg_data_postprocessing
.
get_final_3dvolumes
(
np
.
argmax
(
probs_sagittal
,
axis
=
3
),
test_sub_dict
)
if
verbose
:
print
(
'Saving the intermediate Sagittal prediction ...'
,
flush
=
True
)
newhdr
=
flair_hdr
.
copy
()
newobj
=
nib
.
nifti1
.
Nifti1Image
(
preds_sagittal
,
None
,
header
=
newhdr
)
nib
.
save
(
newobj
,
save_path
)
probs_coronal
=
truenet_tumseg_evaluate
.
evaluate_truenet
(
test_sub_dict
,
model_coronal
,
eval_params
,
device
,
mode
=
'coronal'
,
verbose
=
verbose
)
probs_coronal
=
truenet_tumseg_data_postprocessing
.
resize_to_original_size
(
probs_coronal
,
test_sub_dict
,
plane
=
'coronal'
)
probs_combined
.
append
(
probs_coronal
)
if
intermediate
:
save_path
=
os
.
path
.
join
(
output_dir
,
'Predicted_probmap_truenet_'
+
basename
+
'_coronal.nii.gz'
)
preds_coronal
=
truenet_tumseg_data_postprocessing
.
get_final_3dvolumes
(
np
.
argmax
(
probs_coronal
,
axis
=
3
),
test_sub_dict
)
if
verbose
:
print
(
'Saving the intermediate Coronal prediction ...'
,
flush
=
True
)
newhdr
=
flair_hdr
.
copy
()
newobj
=
nib
.
nifti1
.
Nifti1Image
(
preds_coronal
,
None
,
header
=
newhdr
)
nib
.
save
(
newobj
,
save_path
)
probs_combined
=
np
.
array
(
probs_combined
)
prob_mean
=
np
.
mean
(
probs_combined
,
axis
=
0
)
# pred_mean = truenet_tumseg_data_postprocessing.get_final_3dvolumes(prob_mean, test_sub_dict)
pred_output
=
np
.
argmax
(
prob_mean
,
axis
=
3
)
pred_output
=
truenet_tumseg_data_postprocessing
.
get_final_3dvolumes
(
pred_output
,
test_sub_dict
)
if
postproc
is
False
:
# save_path = os.path.join(output_dir, 'Predicted_probmap_truenet_' + basename + '_final.nii.gz')
# if verbose:
# print('Saving the final prediction ...', flush=True)
#
# newhdr = flair_hdr.copy()
# newobj = nib.nifti1.Nifti1Image(pred_mean, None, header=newhdr)
# nib.save(newobj, save_path)
save_path
=
os
.
path
.
join
(
output_dir
,
'Predicted_output_truenet_'
+
basename
+
'_final.nii.gz'
)
if
verbose
:
print
(
'Saving the final output segmented map ...'
,
flush
=
True
)
newhdr
=
flair_hdr
.
copy
()
newobj
=
nib
.
nifti1
.
Nifti1Image
(
pred_output
,
None
,
header
=
newhdr
)
nib
.
save
(
newobj
,
save_path
)
else
:
probs_tc
=
truenet_tumseg_evaluate
.
evaluate_truenet
(
test_sub_dict
,
model_tc
,
eval_params
,
device
,
mode
=
'tc'
,
verbose
=
verbose
)
probs_tc
=
truenet_tumseg_data_postprocessing
.
resize_to_original_size
(
probs_tc
,
test_sub_dict
,
plane
=
'tc'
)
preds_tc
=
truenet_tumseg_data_postprocessing
.
get_final_3dvolumes
(
np
.
argmax
(
probs_tc
,
axis
=
3
),
test_sub_dict
)
pred_final
=
truenet_tumseg_utils
.
post_processing_including_tc
(
pred_output
,
preds_tc
)
if
intermediate
:
save_path
=
os
.
path
.
join
(
output_dir
,
'Predicted_probmap_truenet_'
+
basename
+
'_tc.nii.gz'
)
if
verbose
:
print
(
'Saving the intermediate Tumour core prediction ...'
,
flush
=
True
)
newhdr
=
flair_hdr
.
copy
()
newobj
=
nib
.
nifti1
.
Nifti1Image
(
preds_tc
,
None
,
header
=
newhdr
)
nib
.
save
(
newobj
,
save_path
)
save_path
=
os
.
path
.
join
(
output_dir
,
'Predicted_output_truenet_'
+
basename
+
'_final.nii.gz'
)
if
verbose
:
print
(
'Saving the final output segmented map ...'
,
flush
=
True
)
newhdr
=
flair_hdr
.
copy
()
newobj
=
nib
.
nifti1
.
Nifti1Image
(
pred_final
,
None
,
header
=
newhdr
)
nib
.
save
(
newobj
,
save_path
)
if
verbose
:
print
(
'Testing complete for all subjects!'
,
flush
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment