Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Michiel Cottaar
gyral_structure
Commits
42b38735
Commit
42b38735
authored
May 03, 2020
by
Michiel Cottaar
Browse files
BUG: fix extent of param_evaluator and run on smaller npos
parent
cdccc8b9
Pipeline
#5298
failed with stage
in 10 minutes and 17 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
gyral_structure/basis/core.py
View file @
42b38735
...
...
@@ -325,18 +325,19 @@ class BasisFunc(object):
with
algorithm
.
set
(
method
=
method
,
override
=
True
):
return
MultEvaluator
(
self
,
req
)
def
param_evaluator
(
self
,
parameters
,
method
=
None
,
nsim
=
2
**
12
):
def
param_evaluator
(
self
,
parameters
,
method
=
None
,
extent
=
0
,
nsim
=
2
**
12
):
"""Returns a function that evaluates the field at a given location
Used for streamline evaluation.
:param parameters: set of parameters for which the field should be evaluated at different positions
:param method: algorithm to use when computing field
:param nsim: number of positions to evaluate simultaneously
:param extent: maximum extent of request to accept (keep at zero for tractography)
:param nsim: maximum number of positions to evaluate simultaneously
:return: function to map the positions to the field
"""
from
.evaluator
import
MultEvaluator
self
.
precompute_evaluator
()
self
.
precompute_evaluator
(
extent
)
with
algorithm
.
set
(
store_matrix
=
False
,
method
=
method
):
sim_evaluator
=
MultEvaluator
(
self
,
request
.
PositionRequest
(
np
.
zeros
((
nsim
,
self
.
ndim
))))
if
sim_evaluator
.
use_mat
:
...
...
@@ -362,7 +363,8 @@ class BasisFunc(object):
for
idx
in
range
(
0
,
flat_all_pos
.
shape
[
0
],
nsim
):
set_pos
=
flat_all_pos
[
idx
:
idx
+
nsim
,
:]
sim_evaluator
.
update_pos
(
set_pos
)
part_res
=
sim_evaluator
(
parameters
,
inverse
=
False
)[:
set_pos
.
shape
[
0
],
:]
part_res
=
sim_evaluator
(
parameters
,
inverse
=
False
)
assert
part_res
.
shape
==
set_pos
.
shape
if
sim_evaluator
.
use_cuda
and
hasattr
(
part_res
,
'get'
):
part_res
=
part_res
.
get
()
flat_res
[
idx
:
idx
+
nsim
]
=
part_res
...
...
@@ -521,22 +523,24 @@ class SumBase(UserList):
"""
self
.
_fixed
=
{}
def
param_evaluator
(
self
,
parameters
,
method
=
None
):
def
param_evaluator
(
self
,
parameters
,
method
=
None
,
extent
=
0
):
"""Returns a function that evaluates the basis function for arbitrary positions given a fixed parameter array
Used to __call__ streamlines
:param parameters: (nparams, ) array defining the parameters for which the field will be evaluated
:param method: Algorithm used to __call__ the basis functions
:param extent: maximum extent of request to accept (keep at zero for tractography)
:return: function that maps positions to vector field
"""
funcs
=
[]
idx_param
=
0
for
idx_elem
,
elem
in
enumerate
(
self
):
if
idx_elem
in
self
.
_fixed
:
funcs
.
append
(
elem
.
param_evaluator
(
self
.
_fixed
[
idx_elem
],
method
=
method
))
funcs
.
append
(
elem
.
param_evaluator
(
self
.
_fixed
[
idx_elem
],
extent
=
extent
,
method
=
method
))
else
:
funcs
.
append
(
elem
.
param_evaluator
(
parameters
[
idx_param
:
idx_param
+
elem
.
nparams
],
method
=
method
))
funcs
.
append
(
elem
.
param_evaluator
(
parameters
[
idx_param
:
idx_param
+
elem
.
nparams
],
extent
=
extent
,
method
=
method
))
idx_param
+=
elem
.
nparams
def
evaluate
(
positions
):
...
...
gyral_structure/basis/evaluator.py
View file @
42b38735
...
...
@@ -163,7 +163,7 @@ class RequestEvaluator(object):
"""If True evaluates the field on the GPU rather than CPU"""
return
self
.
method
in
(
Algorithm
.
cuda
,
Algorithm
.
matrix_cuda
)
def
update_pos
(
self
,
new_
positions
):
def
update_pos
(
self
,
new_
request
):
raise
ValueError
(
f
"Updating positions not implemented for
{
type
(
self
)
}
"
)
...
...
@@ -199,8 +199,8 @@ class IdentityEvaluator(RequestEvaluator):
"""
return
params
def
update_pos
(
self
,
new_
positions
):
self
.
request
=
request
.
PositionRequest
(
new_positions
)
def
update_pos
(
self
,
new_
request
):
self
.
request
=
new_
request
class
FuncRequestEvaluator
(
RequestEvaluator
):
...
...
@@ -327,13 +327,13 @@ class FuncRequestEvaluator(RequestEvaluator):
"""
self
.
results
=
{}
def
update_pos
(
self
,
new_
positions
):
self
.
request
=
request
.
PositionRequest
(
new_positions
)
def
update_pos
(
self
,
new_
request
):
self
.
request
=
new_
request
for
_
,
partial_func
in
self
.
partial_func
:
if
not
hasattr
(
partial_func
,
'update_pos'
):
self
.
partial_func
=
[(
1
,
self
.
basis
.
get_func
(
new_positions
,
self
.
method
))]
self
.
partial_func
=
[(
1
,
self
.
basis
.
get_func
(
new_
request
.
positions
,
self
.
method
))]
break
partial_func
.
update_pos
(
new_positions
)
partial_func
.
update_pos
(
new_
request
.
positions
)
class
MatRequestEvaluator
(
RequestEvaluator
):
...
...
@@ -417,8 +417,8 @@ class MatRequestEvaluator(RequestEvaluator):
def
wrap_qp
(
self
,
qp
):
return
self
.
request
.
wrap_qp
(
qp
,
{
self
.
request
:
self
.
mat
})
def
update_pos
(
self
,
new_
positions
):
self
.
request
=
request
.
PositionRequest
(
new_positions
)
def
update_pos
(
self
,
new_
request
):
self
.
request
=
new_
request
self
.
mat
=
self
.
basis
.
get_full_mat
(
self
.
request
)
if
self
.
use_cuda
:
...
...
@@ -636,10 +636,15 @@ class MultEvaluator(object):
def
update_pos
(
self
,
new_positions
):
if
len
(
self
.
request_list
)
>
1
:
raise
ValueError
(
"Can't update the positions of multiple requests"
)
new_request
=
request
.
PositionRequest
(
new_positions
)
for
evaluator
in
self
.
evaluators
.
flat
:
evaluator
.
update_pos
(
new_
positions
)
evaluator
.
update_pos
(
new_
request
)
for
evaluator
in
self
.
fixed_field_evaluators
[
self
.
request_list
[
0
]]:
evaluator
.
update_pos
(
new_positions
)
evaluator
.
update_pos
(
new_request
)
self
.
fixed_field_evaluators
[
new_request
]
=
self
.
fixed_field_evaluators
[
self
.
request_list
[
0
]]
del
self
.
fixed_field_evaluators
[
self
.
request_list
[
0
]]
self
.
request_list
[
0
]
=
new_request
self
.
full_request
=
new_request
if
self
.
fixed_field
:
self
.
fixed_field
=
{
req
:
sp
.
sum
([
evaluator
(
params
)
for
evaluator
,
params
in
self
.
fixed_field_evaluators
[
req
]],
0
)
...
...
gyral_structure/basis/radial.py
View file @
42b38735
...
...
@@ -149,16 +149,16 @@ class RadialBasis(BasisFunc):
:return: tuple with the request and centroid indices in compressed format
"""
if
self
.
_precomputed_grids
is
not
None
:
print
(
'check grid'
,
req
.
radius
(),
self
.
_precomputed_grids
[
0
])
print
(
req
)
if
self
.
_precomputed_grids
is
not
None
and
req
.
radius
()
<=
self
.
_precomputed_grids
[
0
]:
if
not
hasattr
(
self
,
'_ref_list_of_lists'
):
self
.
_ref_list_of_lists
=
np
.
zeros
(
req
.
npos
,
dtype
=
'object'
)
empty_arr
=
np
.
zeros
(
0
,
dtype
=
'i4'
)
for
idx
in
range
(
req
.
npos
):
self
.
_ref_list_of_lists
[
idx
]
=
empty_arr
list_of_lists
=
self
.
_ref_list_of_lists
.
copy
()
max_size
,
affine
,
intersects
=
self
.
_precomputed_grids
if
(
req
.
radius
()
>
max_size
).
any
():
raise
ValueError
(
"Precomputed results only deal with maximum request radius of {}, "
.
format
(
max_size
)
+
"but request of {} was found"
.
format
(
req
.
radius
().
max
()))
list_of_lists
=
self
.
_ref_list_of_lists
[:
req
.
npos
].
copy
()
_
,
affine
,
intersects
=
self
.
_precomputed_grids
voxels
=
np
.
floor
(
affine
[:
3
,
:
3
].
dot
(
req
.
center
().
T
).
T
+
affine
[:
-
1
,
-
1
]).
astype
(
'i4'
)
use
=
(
voxels
>=
0
).
all
(
-
1
)
&
(
voxels
<
intersects
.
shape
).
all
(
-
1
)
list_of_lists
[
use
]
=
intersects
[
tuple
(
voxels
[
use
].
T
)]
...
...
gyral_structure/test/test_sumbase.py
View file @
42b38735
...
...
@@ -30,7 +30,7 @@ def test_sumbase():
pfull
=
sp
.
concatenate
(
params
)
field
=
sp
.
randn
(
*
pos
.
shape
)
values
=
[
bf
.
get_evaluator
(
req
)(
pf
)
for
bf
,
pf
in
zip
(
basis_list
,
params
)]
print
(
req
,
sp
.
sum
(
values
,
0
)
-
sb1
.
get_evaluator
(
req
)(
pfull
)
)
print
(
'
req
uest'
,
req
)
assert
(
sp
.
sum
(
values
,
0
)
==
sb1
.
get_evaluator
(
req
)(
pfull
)).
all
()
values_rev
=
sp
.
concatenate
([
bf
.
get_evaluator
(
req
)(
field
,
inverse
=
True
)
for
bf
in
basis_list
])
assert
values_rev
.
shape
==
pfull
.
shape
...
...
@@ -38,9 +38,10 @@ def test_sumbase():
sb1
.
fix
(
1
,
params
[
1
])
ppart
=
sp
.
concatenate
((
params
[
0
],
params
[
2
]))
print
(
sp
.
sum
(
values
,
0
)
-
sb1
.
get_evaluator
(
req
)(
ppart
))
assert
(
sp
.
sum
(
values
,
0
)
-
sb1
.
get_evaluator
(
req
)(
ppart
)).
max
()
<
1e-5
assert
(
sp
.
sum
(
values
,
0
)
-
sb1
.
param_evaluator
(
ppart
)(
req
)).
max
()
<
1e-5
assert
(
sp
.
sum
(
values
,
0
)
-
sb1
.
param_evaluator
(
ppart
,
extent
=
req
.
radius
()
if
isinstance
(
req
,
request
.
FieldRequest
)
else
0
)(
req
)).
max
()
<
1e-5
for
b
in
basis_list
:
if
hasattr
(
b
,
'_precomputed_grid'
):
b
.
_precomputed_grid
=
None
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment