Skip to content
Snippets Groups Projects
Verified Commit 79bd6d63 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

Use more elegant encapsulation for variables

parent ba0a832e
No related branches found
No related tags found
No related merge requests found
...@@ -4,7 +4,7 @@ import LinearAlgebra: norm ...@@ -4,7 +4,7 @@ import LinearAlgebra: norm
import StaticArrays: SVector import StaticArrays: SVector
import JuMP: @constraint, @objective, objective_function import JuMP: @constraint, @objective, objective_function
import ...BuildSequences: global_model, global_scanner import ...BuildSequences: global_model, global_scanner
import ...Variables: VariableType, duration, rise_time, flat_time, effective_time, qvec, gradient_strength, slew_rate, inverse_slice_thickness, get_free_variable import ...Variables: VariableType, duration, rise_time, flat_time, effective_time, qvec, gradient_strength, slew_rate, inverse_slice_thickness, get_free_variable, get_pulse
import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent
import ..BaseBuildingBlocks: BaseBuildingBlock import ..BaseBuildingBlocks: BaseBuildingBlock
...@@ -92,6 +92,7 @@ duration(spoilt::SpoiltSliceSelect) = sum(rise_time(spoilt)) + sum(flat_time(spo ...@@ -92,6 +92,7 @@ duration(spoilt::SpoiltSliceSelect) = sum(rise_time(spoilt)) + sum(flat_time(spo
slew_rate(spoilt::SpoiltSliceSelect) = spoilt.slew_rate slew_rate(spoilt::SpoiltSliceSelect) = spoilt.slew_rate
inverse_slice_thickness(spoilt::SpoiltSliceSelect) = spoilt.slew_rate * spoilt.diff_time * duration(spoilt.pulse) * 1e3 inverse_slice_thickness(spoilt::SpoiltSliceSelect) = spoilt.slew_rate * spoilt.diff_time * duration(spoilt.pulse) * 1e3
gradient_strength(spoilt::SpoiltSliceSelect) = slew_rate(spoilt) * max(spoilt.rise_time1, spoilt.fall_time2) gradient_strength(spoilt::SpoiltSliceSelect) = slew_rate(spoilt) * max(spoilt.rise_time1, spoilt.fall_time2)
get_pulse(spoilt::SpoiltSliceSelect) = spoilt.pulse
function all_gradient_strengths(spoilt::SpoiltSliceSelect) function all_gradient_strengths(spoilt::SpoiltSliceSelect)
grad1 = spoilt.slew_rate * rise_time(spoilt)[1] grad1 = spoilt.slew_rate * rise_time(spoilt)[1]
grad2 = grad1 - spoilt.slew_rate * flat_time(spoilt)[1] grad2 = grad1 - spoilt.slew_rate * flat_time(spoilt)[1]
......
...@@ -7,7 +7,7 @@ import JuMP: @constraint ...@@ -7,7 +7,7 @@ import JuMP: @constraint
import StaticArrays: SVector import StaticArrays: SVector
import LinearAlgebra: norm import LinearAlgebra: norm
import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_bandwidth, effective_time, qval_square, duration, set_simple_constraints!, scanner_constraints!, inverse_slice_thickness import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_bandwidth, effective_time, qval_square, duration, set_simple_constraints!, scanner_constraints!, inverse_slice_thickness
import ...Variables: Variables, all_variables_symbols, dwell_time, inverse_fov, inverse_voxel_size, fov, voxel_size import ...Variables: Variables, all_variables_symbols, dwell_time, inverse_fov, inverse_voxel_size, fov, voxel_size, get_gradient, get_pulse, get_readout
import ...BuildSequences: global_model import ...BuildSequences: global_model
import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent, ADC import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent, ADC
import ..BaseBuildingBlocks: BaseBuildingBlock import ..BaseBuildingBlocks: BaseBuildingBlock
...@@ -108,10 +108,6 @@ slew_rate(g::Trapezoid) = g.slew_rate ...@@ -108,10 +108,6 @@ slew_rate(g::Trapezoid) = g.slew_rate
δ(g::Trapezoid) = rise_time(g) + flat_time(g) δ(g::Trapezoid) = rise_time(g) + flat_time(g)
duration(g::Trapezoid) = 2 * rise_time(g) + flat_time(g) duration(g::Trapezoid) = 2 * rise_time(g) + flat_time(g)
for func in (:rise_time, :flat_time, :gradient_strength, :slew_rate, :δ, :duration, :qvec)
@eval $func(bt::BaseTrapezoid) = $func(bt.trapezoid)
end
qvec(g::BaseTrapezoid, ::Nothing, ::Nothing) = δ(g) .* gradient_strength(g) .* 2π qvec(g::BaseTrapezoid, ::Nothing, ::Nothing) = δ(g) .* gradient_strength(g) .* 2π
""" """
...@@ -146,12 +142,8 @@ Base.getindex(pg::SliceSelect, ::Val{:pulse}) = pg.pulse ...@@ -146,12 +142,8 @@ Base.getindex(pg::SliceSelect, ::Val{:pulse}) = pg.pulse
inverse_slice_thickness(ss::SliceSelect) = 1e3 * gradient_strength(ss.trapezoid) .* inverse_bandwidth(ss.pulse) inverse_slice_thickness(ss::SliceSelect) = 1e3 * gradient_strength(ss.trapezoid) .* inverse_bandwidth(ss.pulse)
for func in all_variables_symbols[:pulse] get_pulse(ss::SliceSelect) = ss.pulse
if func in (:inverse_slice_thickness, :slice_thickness) get_gradient(ss::SliceSelect) = ss.trapezoid
continue
end
Variables.$func(ss::SliceSelect) = Variables.$func(ss.pulse)
end
""" """
LineReadout(adc; ramp_overlap=1., orientation=nothing, group=nothing, variables...) LineReadout(adc; ramp_overlap=1., orientation=nothing, group=nothing, variables...)
...@@ -188,11 +180,7 @@ ramp_overlap(lr::LineReadout) = lr.ramp_overlap ...@@ -188,11 +180,7 @@ ramp_overlap(lr::LineReadout) = lr.ramp_overlap
inverse_fov(lr::LineReadout) = @. 1e3 * dwell_time(lr.adc) * gradient_strength(lr.trapezoid) inverse_fov(lr::LineReadout) = @. 1e3 * dwell_time(lr.adc) * gradient_strength(lr.trapezoid)
inverse_voxel_size(lr::LineReadout) = @. 1e3 * duration(lr.adc) * gradient_strength(lr.trapezoid) inverse_voxel_size(lr::LineReadout) = @. 1e3 * duration(lr.adc) * gradient_strength(lr.trapezoid)
for func in all_variables_symbols[:readout] get_readout(lr::LineReadout) = rl.adc
if func in (:inverse_fov, :slice_fov, :inverse_voxel_size, :slice_voxel_size, :ramp_overlap) get_gradient(lr::LineReadout) = rl.trapezoid
continue
end
Variables.$func(lr::LineReadout) = Variables.$func(lr.adc)
end
end end
\ No newline at end of file
...@@ -8,6 +8,7 @@ In addition this defines: ...@@ -8,6 +8,7 @@ In addition this defines:
- [`VariableNotAvailable`](@ref): error raised if variable is not defined for specific [`AbstractBlock`](@ref). - [`VariableNotAvailable`](@ref): error raised if variable is not defined for specific [`AbstractBlock`](@ref).
- [`set_simple_constraints`](@ref): call [`apply_simple_constraint`](@ref) for each keyword argument. - [`set_simple_constraints`](@ref): call [`apply_simple_constraint`](@ref) for each keyword argument.
- [`apply_simple_constraint`](@ref): set a simple equality constraint. - [`apply_simple_constraint`](@ref): set a simple equality constraint.
- [`get_pulse`](@ref)/[`get_gradient`](@ref)/[`get_readout`](@ref): Used to get the pulse/gradient/readout part of a building block
""" """
module Variables module Variables
import JuMP: @variable, Model, @objective, objective_function, value, AbstractJuMPScalar import JuMP: @variable, Model, @objective, objective_function, value, AbstractJuMPScalar
...@@ -133,6 +134,33 @@ function get_free_variable(::Val{:max}) ...@@ -133,6 +134,33 @@ function get_free_variable(::Val{:max})
return var return var
end end
"""
get_pulse(building_block)]
Get the pulse played out during the building block.
Any `pulse` variables not explicitly defined for this building block will be passed on to the pulse.
"""
function get_pulse end
"""
get_gradient(building_block)]
Get the gradient played out during the building block.
Any `gradient` variables not explicitly defined for this building block will be passed on to the gradient.
"""
function get_gradient end
"""
get_readout(building_block)]
Get the readout played out during the building block.
Any `readout` variables not explicitly defined for this building block will be passed on to the readout.
"""
function get_gradient end
""" """
bmat_gradient(gradient::GradientBlock, qstart=(0, 0, 0)) bmat_gradient(gradient::GradientBlock, qstart=(0, 0, 0))
...@@ -165,29 +193,38 @@ function Base.showerror(io::IO, e::VariableNotAvailable) ...@@ -165,29 +193,38 @@ function Base.showerror(io::IO, e::VariableNotAvailable)
end end
for variable_func in keys(variables) for (target_name, all_vars) in pairs(all_variables_symbols)
if variable_func in [:qval_square, :qval] for variable_func in keys(all_vars)
continue if variable_func in [:qval_square, :qval]
end continue
@eval function Variables.$variable_func(bb::BuildingBlock) end
if Variables.$variable_func in keys(alternative_variables) @eval function Variables.$variable_func(bb::BuildingBlock)
alt_var, forward, backward, _ = alternative_variables[Variables.$variable_func]
try try
value = alt_var(bb) if Variables.$variable_func in keys(alternative_variables)
if value isa Number alt_var, forward, backward, _ = alternative_variables[Variables.$variable_func]
return backward(value) try
elseif value isa AbstractArray{<:Number} value = alt_var(bb)
return backward.(value) if value isa Number
return backward(value)
elseif value isa AbstractArray{<:Number}
return backward.(value)
end
catch e
if e isa VariableNotAvailable
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func))
end
rethrow()
end
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func, alt_var))
end end
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func))
catch e catch e
if e isa VariableNotAvailable if e isa VariableNotAvailable && hasmethod(get_$(target_name), Tuple(typeof(bb))) && $(target_name) in (:pulse, :readout)
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func)) return Variables.$variable_func(get_$(target_name)(bb))
end end
rethrow() rethrow()
end end
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func, alt_var))
end end
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func))
end end
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment