From d165e610fde1eb59d0e988a9e425f49836a953c3 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Wed, 31 Jan 2024 18:35:06 +0000
Subject: [PATCH] Add inverse_slice_thickness and better support
 inverse_bandwidth

---
 src/building_blocks.jl                 | 12 ++++++++++++
 src/overlapping/trapezoid_gradients.jl |  4 ++--
 src/pulses/constant_pulses.jl          |  5 ++---
 src/pulses/sinc_pulses.jl              |  5 ++---
 src/variables.jl                       |  4 ++++
 5 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/src/building_blocks.jl b/src/building_blocks.jl
index 9621283..ffcfc85 100644
--- a/src/building_blocks.jl
+++ b/src/building_blocks.jl
@@ -262,9 +262,21 @@ If set to `:min` or `:max`, minimising or maximising this function will be added
 """
 function set_simple_constraints!(model::Model, block::BuildingBlock, kwargs)
     to_funcs = Dict(nameof(fn) => fn for fn in variables(block))
+
+    invert_value(value::VariableType) = 1 / value
+    invert_value(value::Symbol) = invert_value(Val(value))
+    invert_value(::Val{:min}) = Val(:max)
+    invert_value(::Val{:max}) = Val(:min)
+    invert_value(value::AbstractVector) = invert_value.(value)
+    invert_value(value) = value
+
     for (key, value) in kwargs
         if key == :qval
             apply_simple_constraint!(model, qval_square(block), value isa VariableType ? value^2 : value)
+        elseif key == :slice_thickness && :inverse_slice_thickness in keys(to_funcs)
+            apply_simple_constraint!(model, inverse_slice_thickness(block), invert_value(value))
+        elseif key == :bandwidth && :inverse_bandwidth in keys(to_funcs)
+            apply_simple_constraint!(model, inverse_bandwidth(block), invert_value(value))
         else
             apply_simple_constraint!(model, to_funcs[key](block), value)
         end
diff --git a/src/overlapping/trapezoid_gradients.jl b/src/overlapping/trapezoid_gradients.jl
index 875219b..7a5621c 100644
--- a/src/overlapping/trapezoid_gradients.jl
+++ b/src/overlapping/trapezoid_gradients.jl
@@ -5,7 +5,7 @@ module TrapezoidGradients
 
 import JuMP: @constraint, @variable, Model, VariableRef, owner_model, value
 import StaticArrays: SVector
-import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, slice_thickness
+import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_slice_thickness
 import ...BuildingBlocks: duration, set_simple_constraints!, fixed
 import ...BuildSequences: @global_model_constructor
 import ...Gradients: ChangingGradientBlock, ConstantGradientBlock
@@ -135,7 +135,7 @@ inverse_slice_thickness(g::TrapezoidGradient) = isnothing(g.pulse) ? nothing : i
 function variables(tg::TrapezoidGradient) 
     list = [qvec, δ, gradient_strength, duration, rise_time, flat_time]
     if !isnothing(tg.pulse)
-        push!(list, slice_thickness)
+        push!(list, inverse_slice_thickness)
     end
     return list
 end
diff --git a/src/pulses/constant_pulses.jl b/src/pulses/constant_pulses.jl
index f6b8ace..c6e8b83 100644
--- a/src/pulses/constant_pulses.jl
+++ b/src/pulses/constant_pulses.jl
@@ -1,7 +1,7 @@
 module ConstantPulses
 import JuMP: VariableRef, @constraint, @variable, value, Model
 import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed
-import ...Variables: variables, get_free_variable, flip_angle, phase, amplitude, frequency, bandwidth, start_time, end_time, VariableType, duration, effective_time, inverse_bandwidth
+import ...Variables: variables, get_free_variable, flip_angle, phase, amplitude, frequency, start_time, end_time, VariableType, duration, effective_time, inverse_bandwidth
 import ...BuildSequences: @global_model_constructor
 import ..FixedPulses: FixedPulse
 
@@ -43,10 +43,9 @@ phase(pulse::ConstantPulse) = pulse.phase
 frequency(pulse::ConstantPulse) = pulse.frequency
 flip_angle(pulse::ConstantPulse) = amplitude(pulse) * duration(pulse) * 360
 inverse_bandwidth(pulse::ConstantPulse) = duration(pulse) / 3.79098854
-bandwidth(pulse::ConstantPulse) = 1 / inverse_bandwidth(pulse)
 effective_time(pulse::ConstantPulse) = duration(pulse) / 2
 
-variables(::Type{<:ConstantPulse}) = [amplitude, duration, phase, frequency, flip_angle, bandwidth, inverse_bandwidth]
+variables(::Type{<:ConstantPulse}) = [amplitude, duration, phase, frequency, flip_angle, inverse_bandwidth]
 
 function fixed(block::ConstantPulse)
     d = value(duration(block))
diff --git a/src/pulses/sinc_pulses.jl b/src/pulses/sinc_pulses.jl
index 15263c9..b6d8e9f 100644
--- a/src/pulses/sinc_pulses.jl
+++ b/src/pulses/sinc_pulses.jl
@@ -4,7 +4,7 @@ import JuMP: VariableRef, @constraint, @variable, value, Model
 import QuadGK: quadgk
 import Polynomials: fit, Polynomial
 import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed
-import ...Variables: flip_angle, phase, amplitude, frequency, bandwidth, VariableType, variables, get_free_variable, duration, effective_time, inverse_bandwidth
+import ...Variables: flip_angle, phase, amplitude, frequency, VariableType, variables, get_free_variable, duration, effective_time, inverse_bandwidth
 import ...BuildSequences: @global_model_constructor
 import ..FixedPulses: FixedPulse
 
@@ -100,8 +100,7 @@ frequency(pulse::SincPulse) = pulse.frequency
 flip_angle(pulse::SincPulse) = (pulse.nlobe_integral(N_left(pulse)) + pulse.nlobe_integral(N_right(pulse))) * amplitude(pulse) * lobe_duration(pulse) * 360
 lobe_duration(pulse::SincPulse) = pulse.lobe_duration
 inverse_bandwidth(pulse::SincPulse) = lobe_duration(pulse)
-bandwidth(pulse::SincPulse) = 1 / inverse_bandwidth(pulse)
-variables(::Type{<:SincPulse}) = [amplitude, N_left, N_right, duration, phase, frequency, flip_angle, lobe_duration, bandwidth]
+variables(::Type{<:SincPulse}) = [amplitude, N_left, N_right, duration, phase, frequency, flip_angle, lobe_duration, inverse_bandwidth]
 effective_time(pulse::SincPulse) = N_left(pulse) * lobe_duration(pulse)
 
 function fixed(block::SincPulse)
diff --git a/src/variables.jl b/src/variables.jl
index afdc423..44985aa 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -20,6 +20,7 @@ all_variables_symbols = [
         :N_left => "The number of zero crossings of the RF pulse before the main peak",
         :N_right => "The number of zero crossings of the RF pulse after the main peak",
         :slice_thickness => "Slice thickness of an RF pulse that is active during a gradient.",
+        :inverse_bandwidth => "Inverse of the [`slice_thickness`](@ref) in 1/um.",
     ],
 
     # gradients
@@ -61,6 +62,9 @@ variables() = [values(symbol_to_func)...]
 
 
 # Some universal truths
+slice_thickness(bb) = inv(inverse_slice_thickness(bb))
+bandwidth(bb) = inv(inverse_bandwidth(bb))
+
 function qval_square(bb; kwargs...)
     vec = qvec(bb; kwargs...)
     return vec[1]^2 + vec[2]^2 + vec[3]^2
-- 
GitLab