From 5b7224f0534a41a2ce83d4a63147f983dfa0fb73 Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk> Date: Mon, 5 Feb 2024 16:39:37 +0000 Subject: [PATCH] Automatically detect variables for each buildingblock --- src/building_blocks.jl | 117 ++++++++++++------ src/gradients/changing_gradient_blocks.jl | 2 - src/gradients/constant_gradient_blocks.jl | 2 - src/gradients/instant_gradients.jl | 1 - .../gradient_pulses/spoilt_slice_selects.jl | 2 - .../gradient_pulses/trapezoid_gradients.jl | 13 +- .../gradient_readouts/single_lines.jl | 2 - src/pulses/constant_pulses.jl | 2 - src/pulses/fixed_pulses.jl | 2 - src/pulses/instant_pulses.jl | 1 - src/pulses/sinc_pulses.jl | 1 - src/readouts/ADCs.jl | 2 - src/readouts/instant_readouts.jl | 1 - src/sequences.jl | 1 - src/variables.jl | 66 +++++++--- src/wait.jl | 3 - 16 files changed, 133 insertions(+), 85 deletions(-) diff --git a/src/building_blocks.jl b/src/building_blocks.jl index 147fe56..252b4b7 100644 --- a/src/building_blocks.jl +++ b/src/building_blocks.jl @@ -1,7 +1,7 @@ module BuildingBlocks import JuMP: value, Model, @constraint, @objective, objective_function, AbstractJuMPScalar import Printf: @sprintf -import ..Variables: variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, qval_square +import ..Variables: Variables, variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, alternative_variables import ..BuildSequences: global_model, global_scanner, fixed import ..Scanners: Scanner @@ -100,13 +100,49 @@ Function used internally to convert a wide variety of objects into [`BuildingBlo to_block(bb::BuildingBlock) = bb - """ - variables(building_block) + VariableNotAvailable(building_block, variable, alt_variable) -Returns a list of function that can be called to constrain the `building_block`. +Exception raised when a variable function does not support a specific `BuildingBlock`. """ -variables(bb::BuildingBlock) = variables(typeof(bb)) +mutable struct VariableNotAvailable <: Exception + bb :: Type{<:BuildingBlock} + variable :: Function + alt_variable :: Union{Nothing, Function} +end +VariableNotAvailable(bb::Type{<:BuildingBlock}, variable::Function) = VariableNotAvailable(bb, variable, nothing) + +function Base.showerror(io::IO, e::VariableNotAvailable) + if isnothing(e.alt_variable) + print(io, e.variable, " is not available for block of type ", e.bb, ".") + else + print(io, e.variable, " is not available for block of type ", e.bb, ". Please use ", e.alt_variable, " instead to set any contsraints or objective functions.") + end +end + + +for variable_func in keys(variables) + @eval function Variables.$variable_func(bb::BuildingBlock) + if Variables.$variable_func in keys(alternative_variables) + alt_var, forward, backward, _ = alternative_variables[Variables.$variable_func] + try + value = alt_var(bb) + catch e + if e isa VariableNotAvailable + throw(VariableNotAvailable(typeof(bb), variable_func)) + end + rethrow() + end + if value isa Number + return backward(value) + elseif value isa AbstractArray{<:Number} + return backward.(value) + end + throw(VariableNotAvailable(typeof(bb), variable_func, alt_var)) + end + throw(VariableNotAvailable(typeof(bb), variable_func)) + end +end struct BuildingBlockPrinter{T<:BuildingBlock} @@ -141,7 +177,6 @@ _robust_value(possible_tuple::Tuple) = _robust_value([possible_tuple...]) function Base.show(io::IO, printer::BuildingBlockPrinter) block = printer.bb print(io, string(typeof(block)), "(") - variable_names = nameof.(variables(block)) printed_duration = false if !isnothing(printer.start_time) print(io, "t=", @sprintf("%.3g", printer.start_time)) @@ -167,11 +202,18 @@ function Base.show(io::IO, printer::BuildingBlockPrinter) print(io, name, "=", repr(getproperty(block, name)), ", ") end - for fn in variables(block) + for fn in values(variables) if printed_duration && fn == duration continue end - numeric_value = _robust_value(fn(block)) + try + numeric_value = _robust_value(fn(block)) + catch e + if e isa VariableNotAvailable + continue + end + rethrow() + end if isnothing(numeric_value) continue end @@ -196,29 +238,25 @@ If set to a numeric value, a constraint will be added to fix the function value If set to `:min` or `:max`, minimising or maximising this function will be added to the cost function. """ function set_simple_constraints!(block::BuildingBlock, kwargs) - to_funcs = Dict(nameof(fn) => fn for fn in variables(block)) - - invert_value(value::VariableType) = 1 / value - invert_value(value::Symbol) = invert_value(Val(value)) - invert_value(::Val{:min}) = Val(:max) - invert_value(::Val{:max}) = Val(:min) - invert_value(value::AbstractVector) = invert_value.(value) - invert_value(value) = value - for (key, value) in kwargs - if key in keys(to_funcs) - apply_simple_constraint!(to_funcs[key](block), value) - else - if key == :qval - apply_simple_constraint!(to_funcs[:qval_square](block), value isa VariableType ? value^2 : value) - elseif key == :slice_thickness && :inverse_slice_thickness in keys(to_funcs) - apply_simple_constraint!(to_funcs[:inverse_slice_thickness](block), invert_value(value)) - elseif key == :bandwidth && :inverse_bandwidth in keys(to_funcs) - apply_simple_constraint!(to_funcs[:inverse_bandwidth](block), invert_value(value)) - else - error("Trying to set an unrecognised variable $key.") + if key in keys(alternative_variables) + alt_var, forward, backward, to_invert = alternative_variables[Variables.$variable_func] + invert_value(value::VariableType) = forward(value) + invert_value(value::Symbol) = invert_value(Val(value)) + invert_value(::Val{:min}) = to_invert ? Val(:max) : Val(:min) + invert_value(::Val{:max}) = to_invert ? Val(:min) : Val(:max) + invert_value(value::AbstractVector) = invert_value.(value) + invert_value(value) = value + try + apply_simple_constraint!(alt_var(block), invert_value(value)) + return + catch e + if !(e isa VariableNotAvailable) + rethrow() + end end end + apply_simple_constraint!(variables[key](block), value) end nothing end @@ -245,10 +283,9 @@ apply_simple_constraint!(variable::Number, value::Number) = @assert variable ≈ """ - match_blocks!(block1, block2[, property_list]) + match_blocks!(block1, block2, property_list) Matches the listed variables between two [`BuildingBlock`](@ref) objects. -By default all shared variables (i.e., those with the same name) are matched. """ function match_blocks!(block1::BuildingBlock, block2::BuildingBlock, property_list) for fn in property_list @@ -256,11 +293,6 @@ function match_blocks!(block1::BuildingBlock, block2::BuildingBlock, property_li end end -function match_blocks!(block1::BuildingBlock, block2::BuildingBlock) - property_list = intersect(variables(block1), variables(block2)) - match_blocks!(block1, block2, property_list) -end - """ scanner_constraints!(building_block[, scanner]) @@ -280,7 +312,7 @@ end function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner, func::Function) model = global_model() - if func in variables(building_block) + try # apply constraint at this level res_bb = func(building_block) if res_bb isa AbstractVector @@ -299,10 +331,15 @@ function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner, f @constraint model res_bb <= func(scanner) @constraint model res_bb >= -func(scanner) end - elseif building_block isa ContainerBlock - # apply constraints at lower level - for (_, child_block) in get_children_blocks(building_block) - scanner_constraints!(child_block, scanner, func) + + catch e + if !(e isa VariableNotAvailable) + rethrow() + end + if building_block isa ContainerBlock + for (_, child_block) in get_children_blocks(building_block) + scanner_constraints!(child_block, scanner, func) + end end end end diff --git a/src/gradients/changing_gradient_blocks.jl b/src/gradients/changing_gradient_blocks.jl index 5ecba8a..b10699b 100644 --- a/src/gradients/changing_gradient_blocks.jl +++ b/src/gradients/changing_gradient_blocks.jl @@ -63,8 +63,6 @@ function bmat_gradient(cgb::ChangingGradientBlock) end -variables(::Type{<:ChangingGradientBlock}) = [duration, slew_rate, gradient_strength, qvec] - """ split_gradient(constant/changing_gradient_block, times...) diff --git a/src/gradients/constant_gradient_blocks.jl b/src/gradients/constant_gradient_blocks.jl index 858081c..eaf7e1c 100644 --- a/src/gradients/constant_gradient_blocks.jl +++ b/src/gradients/constant_gradient_blocks.jl @@ -48,8 +48,6 @@ function bmat_gradient(cgb::ConstantGradientBlock, qstart) ) end -variables(::Type{<:ConstantGradientBlock}) = [duration, gradient_strength, qvec] - function split_gradient(cgb::ConstantGradientBlock, times::VariableType...) durations = [times[1], [t[2] - t[1] for t in zip(times[1:end-1], times[2:end])]..., duration(cgb) - times[end]] return [ConstantGradientBlock(cgb.gradient_strength, d, cgb.rotate, cgb.scale) for d in durations] diff --git a/src/gradients/instant_gradients.jl b/src/gradients/instant_gradients.jl index 8658b38..08026bd 100644 --- a/src/gradients/instant_gradients.jl +++ b/src/gradients/instant_gradients.jl @@ -78,7 +78,6 @@ end qvec(instant::InstantGradientBlock) = instant.qvec bmat_gradient(::InstantGradientBlock, qstart=nothing) = zeros(3, 3) duration(instant::InstantGradientBlock) = 0. -variables(::Type{<:InstantGradientBlock}) = [qvec, qval] end \ No newline at end of file diff --git a/src/overlapping/gradient_pulses/spoilt_slice_selects.jl b/src/overlapping/gradient_pulses/spoilt_slice_selects.jl index 9f6127b..e467c9f 100644 --- a/src/overlapping/gradient_pulses/spoilt_slice_selects.jl +++ b/src/overlapping/gradient_pulses/spoilt_slice_selects.jl @@ -105,7 +105,5 @@ function all_gradient_strengths(spoilt::SpoiltSliceSelect) return [grad1, grad2, grad3] end -variables(::Type{<:SpoiltSliceSelect}) = [duration, qvec, slew_rate, inverse_slice_thickness, all_gradient_strengths, rise_time, flat_time, fall_time] - end \ No newline at end of file diff --git a/src/overlapping/gradient_pulses/trapezoid_gradients.jl b/src/overlapping/gradient_pulses/trapezoid_gradients.jl index ac9e9b9..f7fad65 100644 --- a/src/overlapping/gradient_pulses/trapezoid_gradients.jl +++ b/src/overlapping/gradient_pulses/trapezoid_gradients.jl @@ -6,6 +6,7 @@ module TrapezoidGradients import JuMP: @constraint, @variable, VariableRef, value import StaticArrays: SVector import LinearAlgebra: norm +import ....BuildingBlocks: VariableNotAvailable import ....Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_slice_thickness, inverse_bandwidth, effective_time, qval_square import ....BuildingBlocks: duration, set_simple_constraints!, RFPulseBlock, scanner_constraints! import ....BuildSequences: global_model @@ -129,15 +130,11 @@ slew_rate(g::TrapezoidGradient) = g.slew_rate δ(g::TrapezoidGradient) = rise_time(g) + flat_time(g) duration(g::TrapezoidGradient) = 2 * rise_time(g) + flat_time(g) qvec(g::TrapezoidGradient, ::Nothing, ::Nothing) = δ(g) .* gradient_strength(g) .* 2π -inverse_slice_thickness(g::TrapezoidGradient) = isnothing(g.pulse) ? nothing : inverse_bandwidth(g) .* gradient_strength(g) .* 1000 - -function variables(tg::TrapezoidGradient) - list = [slew_rate, qvec, δ, gradient_strength, duration, rise_time, flat_time, qval_square] - if !isnothing(tg.pulse) - push!(list, inverse_slice_thickness) +function inverse_slice_thickness(g::TrapezoidGradient) + if isnothing(g.pulse) + throw(VariableNotAvailable(typeof(g), inverse_slice_thickness)) end - return list + return inverse_bandwidth(g) .* gradient_strength(g) .* 1000 end - end \ No newline at end of file diff --git a/src/overlapping/gradient_readouts/single_lines.jl b/src/overlapping/gradient_readouts/single_lines.jl index b3f1021..4a174d1 100644 --- a/src/overlapping/gradient_readouts/single_lines.jl +++ b/src/overlapping/gradient_readouts/single_lines.jl @@ -33,6 +33,4 @@ nsamples(sl::SingleLine) = nsamples(sl.adc) resolution(sl::SingleLine) = resolution(sl.adc) duration(sl::SingleLine) = duration(sl.grad) -variables(::Type{<:SingleLine}) = [dwell_time, gradient_strenth, fov_inverse, voxel_size_inverse, resolution, duration] - end \ No newline at end of file diff --git a/src/pulses/constant_pulses.jl b/src/pulses/constant_pulses.jl index 07ac767..1cc4318 100644 --- a/src/pulses/constant_pulses.jl +++ b/src/pulses/constant_pulses.jl @@ -43,8 +43,6 @@ flip_angle(pulse::ConstantPulse) = amplitude(pulse) * duration(pulse) * 360 inverse_bandwidth(pulse::ConstantPulse) = duration(pulse) * 4 effective_time(pulse::ConstantPulse) = duration(pulse) / 2 -variables(::Type{<:ConstantPulse}) = [amplitude, duration, phase, frequency, flip_angle, inverse_bandwidth] - function fixed(block::ConstantPulse) d = value(duration(block)) final_phase = phase(block) + d * frequency(block) * 360 diff --git a/src/pulses/fixed_pulses.jl b/src/pulses/fixed_pulses.jl index bbc6332..3421e5e 100644 --- a/src/pulses/fixed_pulses.jl +++ b/src/pulses/fixed_pulses.jl @@ -32,8 +32,6 @@ function FixedPulse(time::AbstractVector{<:Number}, amplitude::AbstractVector{<: return FixedPulse(time, amplitude, (time .- time[1]) .* (frequency * 360) .+ phase) end -variables(::Type{<:FixedPulse}) = [duration, flip_angle, effective_time] - duration(fp::FixedPulse) = maximum(fp.time) amplitude(fp::FixedPulse) = maximum(abs.(fp.amplitude)) effective_time(pulse::FixedPulse) = pulse.time[argmax(abs.(pulse.amplitude))] diff --git a/src/pulses/instant_pulses.jl b/src/pulses/instant_pulses.jl index bb8e195..a4fe539 100644 --- a/src/pulses/instant_pulses.jl +++ b/src/pulses/instant_pulses.jl @@ -24,7 +24,6 @@ end flip_angle(instant::InstantRFPulseBlock) = instant.flip_angle phase(instant::InstantRFPulseBlock) = instant.phase duration(::InstantRFPulseBlock) = 0. -variables(::Type{<:InstantRFPulseBlock}) = [flip_angle, phase] effective_time(::InstantRFPulseBlock) = 0. inverse_bandwidth(::InstantRFPulseBlock) = 0. diff --git a/src/pulses/sinc_pulses.jl b/src/pulses/sinc_pulses.jl index 257e4ef..0e0a2c4 100644 --- a/src/pulses/sinc_pulses.jl +++ b/src/pulses/sinc_pulses.jl @@ -99,7 +99,6 @@ frequency(pulse::SincPulse) = pulse.frequency flip_angle(pulse::SincPulse) = (pulse.nlobe_integral(N_left(pulse)) + pulse.nlobe_integral(N_right(pulse))) * amplitude(pulse) * lobe_duration(pulse) * 360 lobe_duration(pulse::SincPulse) = pulse.lobe_duration inverse_bandwidth(pulse::SincPulse) = lobe_duration(pulse) -variables(::Type{<:SincPulse}) = [amplitude, N_left, N_right, duration, phase, frequency, flip_angle, lobe_duration, inverse_bandwidth] effective_time(pulse::SincPulse) = N_left(pulse) * lobe_duration(pulse) function fixed(block::SincPulse) diff --git a/src/readouts/ADCs.jl b/src/readouts/ADCs.jl index a342cdc..f84ce17 100644 --- a/src/readouts/ADCs.jl +++ b/src/readouts/ADCs.jl @@ -50,8 +50,6 @@ time_to_center(adc::ADC) = adc.time_to_center effective_time(adc::ADC) = time_to_center(adc) resolution(adc::ADC) = adc.resolution -variables(::Type{<:ADC}) = [nsamples, dwell_time, duration, time_to_center, resolution, oversample] - function fixed(adc::ADC) # round nsamples during fixing nsamples = Int(round(adc.nsamples)) diff --git a/src/readouts/instant_readouts.jl b/src/readouts/instant_readouts.jl index 9bf9a29..b853791 100644 --- a/src/readouts/instant_readouts.jl +++ b/src/readouts/instant_readouts.jl @@ -12,7 +12,6 @@ It has no parameters or properties to set. struct InstantReadout <: BuildingBlock end -variables(::Type{<:InstantReadout}) = [] to_block(::Type{<:InstantReadout}) = InstantReadout() duration(::InstantReadout) = 0. effective_time(::InstantReadout) = 0. diff --git a/src/sequences.jl b/src/sequences.jl index 95f3eb8..c154238 100644 --- a/src/sequences.jl +++ b/src/sequences.jl @@ -54,7 +54,6 @@ start_time(seq::Sequence, index::Integer) = isone(index) ? start_time(seq) : (st duration(seq::Sequence) = end_time(seq, length(seq)) TR(seq::Sequence) = seq.TR -variables(::Type{<:Sequence}) = [TR] # print timings when printing sequences Base.show(io::IO, seq::Sequence) = print(io, BuildingBlockPrinter(seq, 0., 0)) diff --git a/src/variables.jl b/src/variables.jl index 1e80adc..3e43d15 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -16,12 +16,10 @@ all_variables_symbols = [ :amplitude => "The maximum amplitude of an RF pulse in kHz", :phase => "The angle of the phase of an RF pulse in KHz", :frequency => "The off-resonance frequency of an RF pulse (relative to the Larmor frequency of water) in KHz", - :bandwidth => "Bandwidth of the RF pulse in kHz. If you are going to divide by the bandwidth, it can be more efficient to use the [`inverse_bandwidth`](@ref).", - :inverse_bandwidth => "Inverse of the [`bandwidth`](@ref) of the RF pulse in ms", + :bandwidth => "Bandwidth of the RF pulse in kHz. To set constraints it is often better to use [`inverse_bandwidth`](@ref)", :N_left => "The number of zero crossings of the RF pulse before the main peak", :N_right => "The number of zero crossings of the RF pulse after the main peak", - :slice_thickness => "Slice thickness of an RF pulse that is active during a gradient in mm.", - :inverse_slice_thickness => "Inverse of the [`slice_thickness`](@ref) in 1/mm.", + :slice_thickness => "Slice thickness of an RF pulse that is active during a gradient in mm. To set constraints it is often better to use [`inverse_slice_thickness`](@ref).", ], # gradients @@ -44,7 +42,7 @@ all_variables_symbols = [ ] ] -symbol_to_func = Dict{Symbol, Function}() +variables = Dict{Symbol, Function}() @@ -56,30 +54,68 @@ for (block_symbol, all_functions) in all_variables_symbols @doc $as_string $func_symbol $func_symbol end - symbol_to_func[func_symbol] = new_func + variables[func_symbol] = new_func end end +# helper functions """ - variables(building_block) - variables() + inverse_slice_thickness(pulse) -Returns all functions representing properties of a [`BuildingBlock`](@ref) object. +Inverse of [`slice_thickness`](@ref) in ms. + +It is defined separately from `slice_thickness` to avoid unnecessary divisions in any constraint. +""" +function inverse_slice_thickness end + +""" + inverse_bandwidth(pulse) + +Inverse of [`bandwidth`](@ref) in 1/mm. + +It is defined separately from `bandwidth` to avoid unnecessary divisions in any constraint. +""" +function inverse_bandwidth end + +""" + inverse_fov(readout) + +Inverse of [`fov`](@ref) in 1/mm. + +It is defined separately from `fov` to avoid unnecessary divisions in any constraint. """ -variables() = [values(symbol_to_func)...] +function inverse_fov end +""" + inverse_voxel_size(readout) -# Some universal truths -slice_thickness(bb) = inv(inverse_slice_thickness(bb)) -bandwidth(bb) = inv(inverse_bandwidth(bb)) +Inverse of [`voxel_size`](@ref) in 1/mm. + +It is defined separately from `voxel_size` to avoid unnecessary divisions in any constraint. +""" +function inverse_voxel_size end -function qval_square(bb; kwargs...) - vec = qvec(bb; kwargs...) +""" + qval_square(gradient) + +Square of [`qval`](@ref) in rad/um. + +It is defined separately from `qval` to avoid unnecessary square root in any constraint. +""" +function qval_square(bb, args...; kwargs...) + vec = qvec(bb, args...; kwargs...) return vec[1]^2 + vec[2]^2 + vec[3]^2 end qval(bb; kwargs...) = sqrt(qval_square(bb)) +alternative_variables = Dict( + qval => (qval_square, n->n^2, sqrt, false), + slice_thickness => (inverse_slice_thickness, inv, inv, true), + bandwidth => (inverse_bandwidth, inv, inv, true), + fov => (inverse_fov, inv, inv, true), + voxel_size => (inverse_voxel_size, inv, inv, true), +) # These functions are more fully defined in building_blocks.jl diff --git a/src/wait.jl b/src/wait.jl index 72360ed..c3b4a68 100644 --- a/src/wait.jl +++ b/src/wait.jl @@ -38,9 +38,6 @@ Converts object into a [`WaitBlock`](@ref). """ to_block(time::Union{VariableType, Symbol, Nothing, Val{:min}, Val{:max}}) = WaitBlock(time) - -variables(::Type{WaitBlock}) = [duration] - duration(wb::WaitBlock) = wb.duration end \ No newline at end of file -- GitLab