From 5b7224f0534a41a2ce83d4a63147f983dfa0fb73 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Mon, 5 Feb 2024 16:39:37 +0000
Subject: [PATCH] Automatically detect variables for each buildingblock

---
 src/building_blocks.jl                        | 117 ++++++++++++------
 src/gradients/changing_gradient_blocks.jl     |   2 -
 src/gradients/constant_gradient_blocks.jl     |   2 -
 src/gradients/instant_gradients.jl            |   1 -
 .../gradient_pulses/spoilt_slice_selects.jl   |   2 -
 .../gradient_pulses/trapezoid_gradients.jl    |  13 +-
 .../gradient_readouts/single_lines.jl         |   2 -
 src/pulses/constant_pulses.jl                 |   2 -
 src/pulses/fixed_pulses.jl                    |   2 -
 src/pulses/instant_pulses.jl                  |   1 -
 src/pulses/sinc_pulses.jl                     |   1 -
 src/readouts/ADCs.jl                          |   2 -
 src/readouts/instant_readouts.jl              |   1 -
 src/sequences.jl                              |   1 -
 src/variables.jl                              |  66 +++++++---
 src/wait.jl                                   |   3 -
 16 files changed, 133 insertions(+), 85 deletions(-)

diff --git a/src/building_blocks.jl b/src/building_blocks.jl
index 147fe56..252b4b7 100644
--- a/src/building_blocks.jl
+++ b/src/building_blocks.jl
@@ -1,7 +1,7 @@
 module BuildingBlocks
 import JuMP: value, Model, @constraint, @objective, objective_function, AbstractJuMPScalar
 import Printf: @sprintf
-import ..Variables: variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, qval_square
+import ..Variables: Variables, variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, alternative_variables
 import ..BuildSequences: global_model, global_scanner, fixed
 import ..Scanners: Scanner
 
@@ -100,13 +100,49 @@ Function used internally to convert a wide variety of objects into [`BuildingBlo
 to_block(bb::BuildingBlock) = bb
 
 
-
 """
-    variables(building_block)
+    VariableNotAvailable(building_block, variable, alt_variable)
 
-Returns a list of function that can be called to constrain the `building_block`.
+Exception raised when a variable function does not support a specific `BuildingBlock`.
 """
-variables(bb::BuildingBlock) = variables(typeof(bb))
+mutable struct VariableNotAvailable <: Exception
+    bb :: Type{<:BuildingBlock}
+    variable :: Function
+    alt_variable :: Union{Nothing, Function}
+end
+VariableNotAvailable(bb::Type{<:BuildingBlock}, variable::Function) = VariableNotAvailable(bb, variable, nothing)
+
+function Base.showerror(io::IO, e::VariableNotAvailable)
+    if isnothing(e.alt_variable)
+        print(io, e.variable, " is not available for block of type ", e.bb, ".")
+    else
+        print(io, e.variable, " is not available for block of type ", e.bb, ". Please use ", e.alt_variable, " instead to set any contsraints or objective functions.")
+    end
+end
+
+
+for variable_func in keys(variables)
+    @eval function Variables.$variable_func(bb::BuildingBlock)
+        if Variables.$variable_func in keys(alternative_variables)
+            alt_var, forward, backward, _ = alternative_variables[Variables.$variable_func]
+            try
+                value = alt_var(bb)
+            catch e
+                if e isa VariableNotAvailable
+                    throw(VariableNotAvailable(typeof(bb), variable_func))
+                end
+                rethrow()
+            end
+            if value isa Number
+                return backward(value)
+            elseif value isa AbstractArray{<:Number}
+                return backward.(value)
+            end
+            throw(VariableNotAvailable(typeof(bb), variable_func, alt_var))
+        end
+        throw(VariableNotAvailable(typeof(bb), variable_func))
+    end
+end
 
 
 struct BuildingBlockPrinter{T<:BuildingBlock}
@@ -141,7 +177,6 @@ _robust_value(possible_tuple::Tuple) = _robust_value([possible_tuple...])
 function Base.show(io::IO, printer::BuildingBlockPrinter)
     block = printer.bb
     print(io, string(typeof(block)), "(")
-    variable_names = nameof.(variables(block))
     printed_duration = false
     if !isnothing(printer.start_time)
         print(io, "t=", @sprintf("%.3g", printer.start_time))
@@ -167,11 +202,18 @@ function Base.show(io::IO, printer::BuildingBlockPrinter)
         print(io, name, "=", repr(getproperty(block, name)), ", ")
     end
 
-    for fn in variables(block)
+    for fn in values(variables)
         if printed_duration && fn == duration
             continue
         end
-        numeric_value = _robust_value(fn(block))
+        try
+            numeric_value = _robust_value(fn(block))
+        catch e
+            if e isa VariableNotAvailable
+                continue
+            end
+            rethrow()
+        end
         if isnothing(numeric_value)
             continue
         end
@@ -196,29 +238,25 @@ If set to a numeric value, a constraint will be added to fix the function value
 If set to `:min` or `:max`, minimising or maximising this function will be added to the cost function.
 """
 function set_simple_constraints!(block::BuildingBlock, kwargs)
-    to_funcs = Dict(nameof(fn) => fn for fn in variables(block))
-
-    invert_value(value::VariableType) = 1 / value
-    invert_value(value::Symbol) = invert_value(Val(value))
-    invert_value(::Val{:min}) = Val(:max)
-    invert_value(::Val{:max}) = Val(:min)
-    invert_value(value::AbstractVector) = invert_value.(value)
-    invert_value(value) = value
-
     for (key, value) in kwargs
-        if key in keys(to_funcs)
-            apply_simple_constraint!(to_funcs[key](block), value)
-        else
-            if key == :qval
-                apply_simple_constraint!(to_funcs[:qval_square](block), value isa VariableType ? value^2 : value)
-            elseif key == :slice_thickness && :inverse_slice_thickness in keys(to_funcs)
-                apply_simple_constraint!(to_funcs[:inverse_slice_thickness](block), invert_value(value))
-            elseif key == :bandwidth && :inverse_bandwidth in keys(to_funcs)
-                apply_simple_constraint!(to_funcs[:inverse_bandwidth](block), invert_value(value))
-            else
-                error("Trying to set an unrecognised variable $key.")
+        if key in keys(alternative_variables)
+            alt_var, forward, backward, to_invert = alternative_variables[Variables.$variable_func]
+            invert_value(value::VariableType) = forward(value)
+            invert_value(value::Symbol) = invert_value(Val(value))
+            invert_value(::Val{:min}) = to_invert ? Val(:max) : Val(:min)
+            invert_value(::Val{:max}) = to_invert ? Val(:min) : Val(:max)
+            invert_value(value::AbstractVector) = invert_value.(value)
+            invert_value(value) = value
+            try
+                apply_simple_constraint!(alt_var(block), invert_value(value))
+                return
+            catch e
+                if !(e isa VariableNotAvailable)
+                    rethrow()
+                end
             end
         end
+        apply_simple_constraint!(variables[key](block), value)
     end
     nothing
 end
@@ -245,10 +283,9 @@ apply_simple_constraint!(variable::Number, value::Number) = @assert variable ≈
 
 
 """
-    match_blocks!(block1, block2[, property_list])
+    match_blocks!(block1, block2, property_list)
 
 Matches the listed variables between two [`BuildingBlock`](@ref) objects.
-By default all shared variables (i.e., those with the same name) are matched.
 """
 function match_blocks!(block1::BuildingBlock, block2::BuildingBlock, property_list)
     for fn in property_list
@@ -256,11 +293,6 @@ function match_blocks!(block1::BuildingBlock, block2::BuildingBlock, property_li
     end
 end
 
-function match_blocks!(block1::BuildingBlock, block2::BuildingBlock)
-    property_list = intersect(variables(block1), variables(block2))
-    match_blocks!(block1, block2, property_list)
-end
-
 """
     scanner_constraints!(building_block[, scanner])
 
@@ -280,7 +312,7 @@ end
 
 function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner, func::Function)
     model = global_model()
-    if func in variables(building_block)
+    try
         # apply constraint at this level
         res_bb = func(building_block)
         if res_bb isa AbstractVector
@@ -299,10 +331,15 @@ function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner, f
             @constraint model res_bb <= func(scanner)
             @constraint model res_bb >= -func(scanner)
         end
-    elseif building_block isa ContainerBlock
-        # apply constraints at lower level
-        for (_, child_block) in get_children_blocks(building_block)
-            scanner_constraints!(child_block, scanner, func)
+
+    catch e
+        if !(e isa VariableNotAvailable)
+            rethrow()
+        end
+        if building_block isa ContainerBlock
+            for (_, child_block) in get_children_blocks(building_block)
+                scanner_constraints!(child_block, scanner, func)
+            end
         end
     end
 end
diff --git a/src/gradients/changing_gradient_blocks.jl b/src/gradients/changing_gradient_blocks.jl
index 5ecba8a..b10699b 100644
--- a/src/gradients/changing_gradient_blocks.jl
+++ b/src/gradients/changing_gradient_blocks.jl
@@ -63,8 +63,6 @@ function bmat_gradient(cgb::ChangingGradientBlock)
 end
 
 
-variables(::Type{<:ChangingGradientBlock}) = [duration, slew_rate, gradient_strength, qvec]
-
 """
     split_gradient(constant/changing_gradient_block, times...)
 
diff --git a/src/gradients/constant_gradient_blocks.jl b/src/gradients/constant_gradient_blocks.jl
index 858081c..eaf7e1c 100644
--- a/src/gradients/constant_gradient_blocks.jl
+++ b/src/gradients/constant_gradient_blocks.jl
@@ -48,8 +48,6 @@ function bmat_gradient(cgb::ConstantGradientBlock, qstart)
     )
 end
 
-variables(::Type{<:ConstantGradientBlock}) = [duration, gradient_strength, qvec]
-
 function split_gradient(cgb::ConstantGradientBlock, times::VariableType...)
     durations = [times[1], [t[2] - t[1] for t in zip(times[1:end-1], times[2:end])]..., duration(cgb) - times[end]]
     return [ConstantGradientBlock(cgb.gradient_strength, d, cgb.rotate, cgb.scale) for d in durations]
diff --git a/src/gradients/instant_gradients.jl b/src/gradients/instant_gradients.jl
index 8658b38..08026bd 100644
--- a/src/gradients/instant_gradients.jl
+++ b/src/gradients/instant_gradients.jl
@@ -78,7 +78,6 @@ end
 qvec(instant::InstantGradientBlock) = instant.qvec
 bmat_gradient(::InstantGradientBlock, qstart=nothing) = zeros(3, 3)
 duration(instant::InstantGradientBlock) = 0.
-variables(::Type{<:InstantGradientBlock}) = [qvec, qval]
 
 
 end
\ No newline at end of file
diff --git a/src/overlapping/gradient_pulses/spoilt_slice_selects.jl b/src/overlapping/gradient_pulses/spoilt_slice_selects.jl
index 9f6127b..e467c9f 100644
--- a/src/overlapping/gradient_pulses/spoilt_slice_selects.jl
+++ b/src/overlapping/gradient_pulses/spoilt_slice_selects.jl
@@ -105,7 +105,5 @@ function all_gradient_strengths(spoilt::SpoiltSliceSelect)
     return [grad1, grad2, grad3]
 end
 
-variables(::Type{<:SpoiltSliceSelect}) = [duration, qvec, slew_rate, inverse_slice_thickness, all_gradient_strengths, rise_time, flat_time, fall_time]
-
 
 end
\ No newline at end of file
diff --git a/src/overlapping/gradient_pulses/trapezoid_gradients.jl b/src/overlapping/gradient_pulses/trapezoid_gradients.jl
index ac9e9b9..f7fad65 100644
--- a/src/overlapping/gradient_pulses/trapezoid_gradients.jl
+++ b/src/overlapping/gradient_pulses/trapezoid_gradients.jl
@@ -6,6 +6,7 @@ module TrapezoidGradients
 import JuMP: @constraint, @variable, VariableRef, value
 import StaticArrays: SVector
 import LinearAlgebra: norm
+import ....BuildingBlocks: VariableNotAvailable
 import ....Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_slice_thickness, inverse_bandwidth, effective_time, qval_square
 import ....BuildingBlocks: duration, set_simple_constraints!, RFPulseBlock, scanner_constraints!
 import ....BuildSequences: global_model
@@ -129,15 +130,11 @@ slew_rate(g::TrapezoidGradient) = g.slew_rate
 δ(g::TrapezoidGradient) = rise_time(g) + flat_time(g)
 duration(g::TrapezoidGradient) = 2 * rise_time(g) + flat_time(g)
 qvec(g::TrapezoidGradient, ::Nothing, ::Nothing) = δ(g) .* gradient_strength(g) .* 2π
-inverse_slice_thickness(g::TrapezoidGradient) = isnothing(g.pulse) ? nothing : inverse_bandwidth(g) .* gradient_strength(g) .* 1000
-
-function variables(tg::TrapezoidGradient) 
-    list = [slew_rate, qvec, δ, gradient_strength, duration, rise_time, flat_time, qval_square]
-    if !isnothing(tg.pulse)
-        push!(list, inverse_slice_thickness)
+function inverse_slice_thickness(g::TrapezoidGradient)
+    if isnothing(g.pulse)
+        throw(VariableNotAvailable(typeof(g), inverse_slice_thickness))
     end
-    return list
+    return inverse_bandwidth(g) .* gradient_strength(g) .* 1000
 end
 
-
 end
\ No newline at end of file
diff --git a/src/overlapping/gradient_readouts/single_lines.jl b/src/overlapping/gradient_readouts/single_lines.jl
index b3f1021..4a174d1 100644
--- a/src/overlapping/gradient_readouts/single_lines.jl
+++ b/src/overlapping/gradient_readouts/single_lines.jl
@@ -33,6 +33,4 @@ nsamples(sl::SingleLine) = nsamples(sl.adc)
 resolution(sl::SingleLine) = resolution(sl.adc)
 duration(sl::SingleLine) = duration(sl.grad)
 
-variables(::Type{<:SingleLine}) = [dwell_time, gradient_strenth, fov_inverse, voxel_size_inverse, resolution, duration]
-
 end
\ No newline at end of file
diff --git a/src/pulses/constant_pulses.jl b/src/pulses/constant_pulses.jl
index 07ac767..1cc4318 100644
--- a/src/pulses/constant_pulses.jl
+++ b/src/pulses/constant_pulses.jl
@@ -43,8 +43,6 @@ flip_angle(pulse::ConstantPulse) = amplitude(pulse) * duration(pulse) * 360
 inverse_bandwidth(pulse::ConstantPulse) = duration(pulse) * 4
 effective_time(pulse::ConstantPulse) = duration(pulse) / 2
 
-variables(::Type{<:ConstantPulse}) = [amplitude, duration, phase, frequency, flip_angle, inverse_bandwidth]
-
 function fixed(block::ConstantPulse)
     d = value(duration(block))
     final_phase = phase(block) + d * frequency(block) * 360
diff --git a/src/pulses/fixed_pulses.jl b/src/pulses/fixed_pulses.jl
index bbc6332..3421e5e 100644
--- a/src/pulses/fixed_pulses.jl
+++ b/src/pulses/fixed_pulses.jl
@@ -32,8 +32,6 @@ function FixedPulse(time::AbstractVector{<:Number}, amplitude::AbstractVector{<:
     return FixedPulse(time, amplitude, (time .- time[1]) .* (frequency * 360) .+ phase)
 end
 
-variables(::Type{<:FixedPulse}) = [duration, flip_angle, effective_time]
-
 duration(fp::FixedPulse) = maximum(fp.time)
 amplitude(fp::FixedPulse) = maximum(abs.(fp.amplitude))
 effective_time(pulse::FixedPulse) = pulse.time[argmax(abs.(pulse.amplitude))]
diff --git a/src/pulses/instant_pulses.jl b/src/pulses/instant_pulses.jl
index bb8e195..a4fe539 100644
--- a/src/pulses/instant_pulses.jl
+++ b/src/pulses/instant_pulses.jl
@@ -24,7 +24,6 @@ end
 flip_angle(instant::InstantRFPulseBlock) = instant.flip_angle
 phase(instant::InstantRFPulseBlock) = instant.phase
 duration(::InstantRFPulseBlock) = 0.
-variables(::Type{<:InstantRFPulseBlock}) = [flip_angle, phase]
 effective_time(::InstantRFPulseBlock) = 0.
 inverse_bandwidth(::InstantRFPulseBlock) = 0.
 
diff --git a/src/pulses/sinc_pulses.jl b/src/pulses/sinc_pulses.jl
index 257e4ef..0e0a2c4 100644
--- a/src/pulses/sinc_pulses.jl
+++ b/src/pulses/sinc_pulses.jl
@@ -99,7 +99,6 @@ frequency(pulse::SincPulse) = pulse.frequency
 flip_angle(pulse::SincPulse) = (pulse.nlobe_integral(N_left(pulse)) + pulse.nlobe_integral(N_right(pulse))) * amplitude(pulse) * lobe_duration(pulse) * 360
 lobe_duration(pulse::SincPulse) = pulse.lobe_duration
 inverse_bandwidth(pulse::SincPulse) = lobe_duration(pulse)
-variables(::Type{<:SincPulse}) = [amplitude, N_left, N_right, duration, phase, frequency, flip_angle, lobe_duration, inverse_bandwidth]
 effective_time(pulse::SincPulse) = N_left(pulse) * lobe_duration(pulse)
 
 function fixed(block::SincPulse)
diff --git a/src/readouts/ADCs.jl b/src/readouts/ADCs.jl
index a342cdc..f84ce17 100644
--- a/src/readouts/ADCs.jl
+++ b/src/readouts/ADCs.jl
@@ -50,8 +50,6 @@ time_to_center(adc::ADC) = adc.time_to_center
 effective_time(adc::ADC) = time_to_center(adc)
 resolution(adc::ADC) = adc.resolution
 
-variables(::Type{<:ADC}) = [nsamples, dwell_time, duration, time_to_center, resolution, oversample]
-
 function fixed(adc::ADC)
     # round nsamples during fixing
     nsamples = Int(round(adc.nsamples))
diff --git a/src/readouts/instant_readouts.jl b/src/readouts/instant_readouts.jl
index 9bf9a29..b853791 100644
--- a/src/readouts/instant_readouts.jl
+++ b/src/readouts/instant_readouts.jl
@@ -12,7 +12,6 @@ It has no parameters or properties to set.
 struct InstantReadout <: BuildingBlock
 end
 
-variables(::Type{<:InstantReadout}) = []
 to_block(::Type{<:InstantReadout}) = InstantReadout()
 duration(::InstantReadout) = 0.
 effective_time(::InstantReadout) = 0.
diff --git a/src/sequences.jl b/src/sequences.jl
index 95f3eb8..c154238 100644
--- a/src/sequences.jl
+++ b/src/sequences.jl
@@ -54,7 +54,6 @@ start_time(seq::Sequence, index::Integer) = isone(index) ? start_time(seq) : (st
 duration(seq::Sequence) = end_time(seq, length(seq))
 
 TR(seq::Sequence) = seq.TR
-variables(::Type{<:Sequence}) = [TR]
 
 # print timings when printing sequences
 Base.show(io::IO, seq::Sequence) = print(io, BuildingBlockPrinter(seq, 0., 0))
diff --git a/src/variables.jl b/src/variables.jl
index 1e80adc..3e43d15 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -16,12 +16,10 @@ all_variables_symbols = [
         :amplitude => "The maximum amplitude of an RF pulse in kHz",
         :phase => "The angle of the phase of an RF pulse in KHz",
         :frequency => "The off-resonance frequency of an RF pulse (relative to the Larmor frequency of water) in KHz",
-        :bandwidth => "Bandwidth of the RF pulse in kHz. If you are going to divide by the bandwidth, it can be more efficient to use the [`inverse_bandwidth`](@ref).",
-        :inverse_bandwidth => "Inverse of the [`bandwidth`](@ref) of the RF pulse in ms",
+        :bandwidth => "Bandwidth of the RF pulse in kHz. To set constraints it is often better to use [`inverse_bandwidth`](@ref)",
         :N_left => "The number of zero crossings of the RF pulse before the main peak",
         :N_right => "The number of zero crossings of the RF pulse after the main peak",
-        :slice_thickness => "Slice thickness of an RF pulse that is active during a gradient in mm.",
-        :inverse_slice_thickness => "Inverse of the [`slice_thickness`](@ref) in 1/mm.",
+        :slice_thickness => "Slice thickness of an RF pulse that is active during a gradient in mm. To set constraints it is often better to use [`inverse_slice_thickness`](@ref).",
     ],
 
     # gradients
@@ -44,7 +42,7 @@ all_variables_symbols = [
     ]
 ]
 
-symbol_to_func = Dict{Symbol, Function}()
+variables = Dict{Symbol, Function}()
 
 
 
@@ -56,30 +54,68 @@ for (block_symbol, all_functions) in all_variables_symbols
             @doc $as_string $func_symbol
             $func_symbol
         end
-        symbol_to_func[func_symbol] = new_func
+        variables[func_symbol] = new_func
     end
 end
 
 
+# helper functions
 """
-    variables(building_block)
-    variables()
+    inverse_slice_thickness(pulse)
 
-Returns all functions representing properties of a [`BuildingBlock`](@ref) object.
+Inverse of [`slice_thickness`](@ref) in ms.
+
+It is defined separately from `slice_thickness` to avoid unnecessary divisions in any constraint.
+"""
+function inverse_slice_thickness end
+
+"""
+    inverse_bandwidth(pulse)
+
+Inverse of [`bandwidth`](@ref) in 1/mm.
+
+It is defined separately from `bandwidth` to avoid unnecessary divisions in any constraint.
+"""
+function inverse_bandwidth end
+
+"""
+    inverse_fov(readout)
+
+Inverse of [`fov`](@ref) in 1/mm.
+
+It is defined separately from `fov` to avoid unnecessary divisions in any constraint.
 """
-variables() = [values(symbol_to_func)...]
+function inverse_fov end
 
+"""
+    inverse_voxel_size(readout)
 
-# Some universal truths
-slice_thickness(bb) = inv(inverse_slice_thickness(bb))
-bandwidth(bb) = inv(inverse_bandwidth(bb))
+Inverse of [`voxel_size`](@ref) in 1/mm.
+
+It is defined separately from `voxel_size` to avoid unnecessary divisions in any constraint.
+"""
+function inverse_voxel_size end
 
-function qval_square(bb; kwargs...)
-    vec = qvec(bb; kwargs...)
+"""
+    qval_square(gradient)
+
+Square of [`qval`](@ref) in rad/um.
+
+It is defined separately from `qval` to avoid unnecessary square root in any constraint.
+"""
+function qval_square(bb, args...; kwargs...)
+    vec = qvec(bb, args...; kwargs...)
     return vec[1]^2 + vec[2]^2 + vec[3]^2
 end
 qval(bb; kwargs...) = sqrt(qval_square(bb))
 
+alternative_variables = Dict(
+    qval => (qval_square, n->n^2, sqrt, false),
+    slice_thickness => (inverse_slice_thickness, inv, inv, true),
+    bandwidth => (inverse_bandwidth, inv, inv, true),
+    fov => (inverse_fov, inv, inv, true),
+    voxel_size => (inverse_voxel_size, inv, inv, true),
+)
 
 
 # These functions are more fully defined in building_blocks.jl
diff --git a/src/wait.jl b/src/wait.jl
index 72360ed..c3b4a68 100644
--- a/src/wait.jl
+++ b/src/wait.jl
@@ -38,9 +38,6 @@ Converts object into a [`WaitBlock`](@ref).
 """
 to_block(time::Union{VariableType, Symbol, Nothing, Val{:min}, Val{:max}}) = WaitBlock(time)
 
-
-variables(::Type{WaitBlock}) = [duration]
-
 duration(wb::WaitBlock) = wb.duration
 
 end
\ No newline at end of file
-- 
GitLab