Skip to content
Snippets Groups Projects
Unverified Commit 4e12f353 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

REF: make post-hoc adjustment more flexible

Expected keywords and group names are now not hard-coded
parent 193d85d7
No related branches found
No related tags found
1 merge request!5Resolve "Allow slice selection"
This commit is part of merge request !5. Comments created here will be created in the context of that merge request.
module AbstractTypes
import ...Variables: AbstractBlock, variables, adjustable, gradient_orientation, @defvar
import ...Variables: AbstractBlock, variables, adjust_groups, gradient_orientation, @defvar
"""
Super-type for all individual components that form an MRI sequence (i.e., RF pulse, gradient waveform, or readout event).
......@@ -110,8 +110,8 @@ It should be infinite if the component is linear.
split_timestep(comp_tuple::Tuple{<:Number, <:EventComponent}, precision::Number) = split_timestep(comp_tuple[2], precision)
adjustable(::RFPulseComponent) = :pulse
adjustable(::GradientWaveform) = :gradient
adjust_groups(p::RFPulseComponent) = [p.group, :pulse]
adjust_groups(g::GradientWaveform) = [g.group, :gradient]
gradient_orientation(gw::GradientWaveform{1}) = gw.orientation
......
module NoGradientBlocks
import StaticArrays: SVector, SMatrix
import ....Variables: VariableType, get_free_variable, adjustable, variables, @defvar
import ....Variables: VariableType, get_free_variable, adjust_groups, variables, @defvar
import ...AbstractTypes: GradientWaveform
import ..ChangingGradientBlocks: split_gradient
......@@ -35,6 +35,6 @@ function split_gradient(ngb::NoGradient, times::VariableType...)
return [NoGradient(d) for d in durations]
end
adjustable(::NoGradient) = :false
adjust_groups(::NoGradient) = Symbol[]
end
\ No newline at end of file
module InstantGradients
import StaticArrays: SVector, SMatrix
import JuMP: @constraint
import ...Variables: @defvar, VariableType, variables, get_free_variable, set_simple_constraints!, make_generic, adjust_internal, adjustable, gradient_orientation, apply_simple_constraint!
import ...Variables: @defvar, VariableType, variables, get_free_variable, set_simple_constraints!, make_generic, adjust_internal, adjust_groups, gradient_orientation, apply_simple_constraint!
import ..AbstractTypes: EventComponent, GradientWaveform
"""
......@@ -62,7 +62,7 @@ end
make_generic(ig::InstantGradient) = ig
adjustable(::InstantGradient) = :gradient
adjust_groups(g::InstantGradient) = [g.group, :gradient]
gradient_orientation(ig::InstantGradient{1}) = ig.orientation
......
......@@ -6,7 +6,7 @@ module Trapezoids
import JuMP: @constraint
import StaticArrays: SVector
import LinearAlgebra: norm
import ...Variables: variables, @defvar, scanner_constraints!, get_free_variable, set_simple_constraints!, gradient_orientation, VariableType, get_gradient, get_pulse, get_readout, adjustable, adjust_internal, apply_simple_constraint!, add_cost_function!
import ...Variables: variables, @defvar, scanner_constraints!, get_free_variable, set_simple_constraints!, gradient_orientation, VariableType, get_gradient, get_pulse, get_readout, adjust_groups, adjust_internal, apply_simple_constraint!, add_cost_function!
import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent, ADC
import ...Containers: BaseBuildingBlock
......@@ -146,7 +146,7 @@ variables.δ
@defvar qvec(g::BaseTrapezoid, ::Nothing, ::Nothing) = variables.δ(g) .* variables.gradient_strength(g) .* 2π
adjustable(::BaseTrapezoid) = :gradient
adjust_groups(t::Trapezoid) = [t.group, :gradient]
function adjust_internal(trap::Trapezoid1D; orientation=nothing, scale=1., rotation=nothing)
if !isnothing(orientation) && !isnothing(rotation)
......
......@@ -3,10 +3,12 @@ Define post-fitting adjustments of the sequences
"""
module PostHoc
import ..Variables: AbstractBlock, adjust_internal, adjustable, adjust
import ..Variables: AbstractBlock, adjust_internal, adjust_groups, adjust
import ..Components: GradientWaveform, RFPulseComponent, BaseComponent, NoGradient
import ..Containers: ContainerBlock, Sequence, Wait
const UsedNamesType = Dict{Symbol, Set{Symbol}}
"""
adjust(block; kwargs...)
......@@ -34,7 +36,11 @@ To affect all gradients or pulses, use `gradient=` or `pulse`, e.g.
will divide the amplitude of all RV pulses by two.
"""
function adjust(block::AbstractBlock; merge=true, kwargs...)
used_names = Set{Symbol}()
invalid_type = Set(key for (key, value) in pairs(kwargs) if !(value isa NamedTuple))
if length(invalid_type) > 0
error("All `adjust` keywords except for merge should be a NamedTuple, like (scale=3, ). This is not the case for: $(invalid_type).")
end
used_names = UsedNamesType()
n_adjust, kwargs_list = adjust_kwargs_list(; kwargs...)
if isnothing(n_adjust)
res = adjust_helper(block, used_names; kwargs_list[1]...)
......@@ -48,11 +54,22 @@ function adjust(block::AbstractBlock; merge=true, kwargs...)
end
end
unused_names = filter(keys(kwargs)) do key
!(key in used_names)
!(key in keys(used_names))
end
if length(unused_names) > 0
@warn "Some group/type names were not used in call to `MRIBuilder.adjust`, namely: $(unused_names)."
end
for group_name in keys(kwargs)
if group_name in unused_names
continue
end
unused_keys = filter(keys(kwargs[group_name])) do key
!(key in used_names[group_name])
end
if length(unused_keys) > 0
@warn "Some keywords provided for group `$(group_name)` were not used, namely: $(unused_keys)."
end
end
res
end
......@@ -90,33 +107,35 @@ function adjust_kwargs_list(; kwargs...)
return (n_adjust, kwargs_list)
end
function adjust_helper(block::AbstractBlock, used_names::Set{Symbol}; gradient=(), pulse=(), kwargs...)
function adjust_helper(block::AbstractBlock, used_names::UsedNamesType; kwargs...)
params = []
adjust_type = adjustable(block)
if adjust_type != :false && (!isnothing(block.group) && (block.group in keys(kwargs)))
push!(used_names, block.group)
new_block = adjust_internal(block; kwargs[block.group]...)
elseif adjust_type == :gradient
push!(used_names, :gradient)
new_block = adjust_internal(block; gradient...)
elseif adjust_type == :pulse
push!(used_names, :pulse)
new_block = adjust_internal(block; pulse...)
else
new_block = block
for prop_name in propertynames(block)
push!(params, adjust_helper(getproperty(block, prop_name), used_names; kwargs...))
end
new_block = typeof(block)(params...)
for group in adjust_groups(new_block)
if group in keys(kwargs)
if !(group in keys(used_names))
used_names[group] = Set{Symbol}()
end
all_available_kwargs = kwargs[group]
use_kwargs = reduce(vcat, Base.kwarg_decl.(methods(adjust_internal, (typeof(new_block), ))))
@assert length(use_kwargs) > 0 "Invalid definition of `internal_kwargs` for $(typeof(new_block))"
internal_kwargs = Dict(key => value for (key, value) in pairs(all_available_kwargs) if key in use_kwargs)
union!(used_names[group], keys(internal_kwargs))
return adjust_internal(block; internal_kwargs...)
end
for prop_name in propertynames(new_block)
push!(params, adjust_helper(getproperty(new_block, prop_name), used_names; gradient=gradient, pulse=pulse, kwargs...))
end
return typeof(block)(params...)
return new_block
end
adjust_helper(some_value, used_names::Set{Symbol}; kwargs...) = some_value
adjust_helper(array_variable::AbstractArray, used_names::Set{Symbol}; kwargs...) = map(array_variable) do v adjust_helper(v, used_names; kwargs...) end
adjust_helper(dict_variable::AbstractDict, used_names::Set{Symbol}; kwargs...) = typeof(dict_variable)(k => adjust_helper(v, used_names; kwargs...) for (k, v) in pairs(dict_variable))
adjust_helper(tuple_variable::Tuple, used_names::Set{Symbol}; kwargs...) = map(tuple_variable) do v adjust_helper(v, used_names; kwargs...) end
adjust_helper(pair:: Pair, used_names::Set{Symbol}; kwargs...) = adjust_helper(pair[1], used_names; kwargs...) => adjust_helper(pair[2], used_names; kwargs...)
adjust_helper(some_value, used_names::UsedNamesType; kwargs...) = some_value
adjust_helper(array_variable::AbstractArray, used_names::UsedNamesType; kwargs...) = map(array_variable) do v adjust_helper(v, used_names; kwargs...) end
adjust_helper(dict_variable::AbstractDict, used_names::UsedNamesType; kwargs...) = typeof(dict_variable)(k => adjust_helper(v, used_names; kwargs...) for (k, v) in pairs(dict_variable))
adjust_helper(tuple_variable::Tuple, used_names::UsedNamesType; kwargs...) = map(tuple_variable) do v adjust_helper(v, used_names; kwargs...) end
adjust_helper(pair:: Pair, used_names::UsedNamesType; kwargs...) = adjust_helper(pair[1], used_names; kwargs...) => adjust_helper(pair[2], used_names; kwargs...)
"""
......
......@@ -37,21 +37,23 @@ end
Returns the adjusted blocks and add any keywords used in the process to `names_used`.
This is a helper function used by `adjust`.
This is a helper function used by [`adjust`](@ref).
It should be defined for any block that is adjustable (as defined by [`adjust_groups`](@ref)).
"""
function adjust_internal end
"""
adjustable(block)
adjust_groups(block)
Returns whether a sequence, building block, or component can be adjusted
Returns an array of keywords in [`adjust`](@ref) that should affect a specfic block.
Can return one of:
- `:false`: not adjustable
If any of these keywords are present in [`adjust`](@ref), then [`adjust_internal`](@ref) will be called.
Some standard keywords are:
- `:gradient`: expects gradient adjustment parameters
- `:pulse`: expects RF pulse adjustment parameters
"""
adjustable(::AbstractBlock) = :false
adjust_groups(::AbstractBlock) = Symbol[]
# further defined in post_hoc.jl
function adjust end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment