From 9f873240a57c04f7ab2caf9c4ad376bbec901305 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Sun, 28 Jan 2024 21:55:14 +0000
Subject: [PATCH] Swtich pulsedGradient to container

---
 src/gradients/instant_gradients.jl |  53 ++++++++------
 src/gradients/pulsed_gradients.jl  | 108 +++++++++++++++++------------
 2 files changed, 95 insertions(+), 66 deletions(-)

diff --git a/src/gradients/instant_gradients.jl b/src/gradients/instant_gradients.jl
index df8af2e..62cb576 100644
--- a/src/gradients/instant_gradients.jl
+++ b/src/gradients/instant_gradients.jl
@@ -11,7 +11,7 @@ import ..FixedGradients: FixedInstantGradient
 Defines an instantaneous gradient.
 
 ## Parameters
-- `orientation` sets the gradient orienation (ignored if `qvec` is set). Can be set to a vector for a fixed orientation. Otherwise the orientation will be aligned with the `rotate` (if set) or fully free (if `rotate` is nothing).
+- `orientation` sets the gradient orienation (ignored if `qvec` is set). Can be set to a vector for a fixed orientation. Otherwise the orientation will be aligned with the `rotate` (if set) or fully free (if `rotate` is nothing). Set to :flip to point in the inverse of the user-provided `rotate`.
 - `rotate`: with which user-set parameter will this gradient be rotated (e.g., :bvec). Default is no rotation.
 - `scale`: with which user-set parameter will this gradient be scaled (e.g., :bval). Default is no scaling.
 
@@ -28,31 +28,40 @@ end
 
 @global_model_constructor InstantGradientBlock
 
-function InstantGradientBlock(model::Model; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing)
-    used_qval = false
-    if isnothing(qvec)
-        if isnothing(orientation)
-            if isnothing(rotate)
-                qvec = SVector{3}(
-                    get_free_variable(nothing),
-                    get_free_variable(nothing),
-                    get_free_variable(nothing),
-                )
-            else
-                qvec = SVector{3}(
-                    get_free_variable(qval),
-                    0.,
-                    0.,
-                )
-                used_qval = true
-            end
+function interpret_orientation(orientation, qval, qvec, will_rotate)
+    if !isnothing(qvec)
+        return (false, qvec)
+    end
+    if orientation == :flip
+        @assert will_rotate "setting `orientation=:flip` only makes sense if the `rotate` is set as well."
+        return (true, SVector{3}(
+            -get_free_variable(qval),
+            0.,
+            0.,
+        ))
+    end
+    if isnothing(orientation)
+        if !will_rotate
+            return (false, SVector{3}(
+                get_free_variable(nothing),
+                get_free_variable(nothing),
+                get_free_variable(nothing),
+            ))
         else
-            qval = get_free_variable(qval)
-            qvec = (orientation ./ norm(orientation)) * qval
-            used_qval = true
+            return (true, SVector{3}(
+                get_free_variable(qval),
+                0.,
+                0.,
+            ))
         end
+    else
+        qval = get_free_variable(qval)
+        return (true, (orientation ./ norm(orientation)) * qval)
     end
+end
 
+function InstantGradientBlock(model::Model; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing)
+    (used_qval, qvec) = interpret_orientation(orientation, qval, qvec, !isnothing(rotate))
     res = InstantGradientBlock(
         model,
         qvec,
diff --git a/src/gradients/pulsed_gradients.jl b/src/gradients/pulsed_gradients.jl
index 1f5a846..0ac4681 100644
--- a/src/gradients/pulsed_gradients.jl
+++ b/src/gradients/pulsed_gradients.jl
@@ -6,9 +6,11 @@ module PulsedGradients
 import JuMP: @constraint, @variable, Model, VariableRef, owner_model, value
 import StaticArrays: SVector
 import ...Variables: qval, bval, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType
-import ...BuildingBlocks: GradientBlock, duration, set_simple_constraints!, fixed
+import ...BuildingBlocks: ContainerBlock, duration, set_simple_constraints!, fixed, start_time, get_children_blocks
 import ...BuildSequences: @global_model_constructor
 import ..FixedGradients: FixedGradient
+import ..ChangingGradientBlocks: ChangingGradientBlock
+import ..ConstantGradientBlocks: ConstantGradientBlock
 
 
 """
@@ -17,7 +19,9 @@ import ..FixedGradients: FixedGradient
 Defines a trapezoidal pulsed gradient
 
 ## Parameters
-- `orientation` sets the gradient orienation. Can be set to a vector for a fixed orientation. Alternatively, can be set to :bvec (default) to rotate with the user-provided `bvecs` or to :neg_bvec to always be the reverse of the `bvecs`.
+- `orientation` sets the gradient orienation (ignored if `qvec` is set). Can be set to a vector for a fixed orientation. Otherwise the orientation will be aligned with the `rotate` (if set) or fully free (if `rotate` is nothing). Set to :flip to point in the inverse of the user-provided `rotate`.
+- `rotate`: with which user-set parameter will this gradient be rotated (e.g., :bvec). Default is no rotation.
+- `scale`: with which user-set parameter will this gradient be scaled (e.g., :bval). Default is no scaling.
 
 ## Variables
 Variables can be set during construction or afterwards as an attribute.
@@ -33,74 +37,90 @@ If not set, they will be determined during the sequence optimisation.
 
 The [`bvalue`](@ref) can be constrained for multiple gradient pulses.
 """
-mutable struct PulsedGradient <: GradientBlock
+mutable struct PulsedGradient <: ContainerBlock
     model :: Model
-    orientation :: Any
-    slew_rate :: VariableType
-    rise_time :: VariableType
-    flat_time :: VariableType
+    rise :: ChangingGradientBlock
+    flat :: ConstantGradientBlock
+    fall :: ChangingGradientBlock
+    slew_rate_vec :: SVector{3, VariableType}
+    scaling :: Union{Nothing, VariableType}
 end
 
 @global_model_constructor PulsedGradient
 
-function PulsedGradient(model::Model; orientation=:bvec, slew_rate=nothing, rise_time=nothing, flat_time=nothing, kwargs...)
+function PulsedGradient(model::Model; orientation=nothing, rise_time=nothing, flat_time=nothing, rotate=nothing, scale=nothing, kwargs...)
+    if isnothing(orientation) && isnothing(rotate)
+        rate_1d = nothing
+        slew_rate = (
+            get_free_variable(model, nothing),
+            get_free_variable(model, nothing),
+            get_free_variable(model, nothing),
+        )
+    else
+        rate_1d = get_free_variable(model, nothing)
+        @constraint model rate_1d >= 0
+        if isnothing(orientation)
+            rate_1d = get_free_variable(model, nothing)
+            slew_rate = (
+                rate_1d,
+                0.,
+                0.,
+            )
+        elseif orientation == :flip
+            @assert !isnothing(rotate) "setting `orientation=:flip` only makes sense if the `rotate` is set as well."
+            slew_rate = (
+                -rate_1d,
+                0.,
+                0.,
+            )
+        else
+            slew_rate = rate_1d .* (orientation ./ norm(orientation))
+        end
+    end
+    rise_time = get_free_variable(model, rise_time)
+    grad_vec = slew_rate .* rise_time
+
     res = PulsedGradient(
         model,
-        orientation,
-        [get_free_variable(model, value) for value in (slew_rate, rise_time, flat_time)]...
+        ChangingGradientBlock(zeros(3), grad_vec, rise_time, rotate, scale)
+        ConstantGradientBlock(grad_vec, flat_time, rotate, scale)
+        ChangingGradientBlock(grad_vec, zeros(3), rise_time, rotate, scale)
+        slew_rate,
+        rate_1d,
     )
 
     set_simple_constraints!(model, res, kwargs)
     @constraint model res.flat_time >= 0
     @constraint model res.rise_time >= 0
-    @constraint model res.slew_rate >= 0
     return res
 end
 
 rise_time(pg::PulsedGradient) = pg.rise_time
 flat_time(pg::PulsedGradient) = pg.flat_time
-gradient_strength(g::PulsedGradient) = rise_time(g) * slew_rate(g)
-slew_rate(g::PulsedGradient) = g.slew_rate
+gradient_strength_vec(g::PulsedGradient) = rise_time(g) * slew_rate(g)
+gradient_strength(g::PulsedGradient) = isnothing(g.scaling) ? maximum(gradient_strength_vec(g)) : (res.scaling * rise_time(g))
+slew_rate_vec(g::PulsedGradient) = g.slew_rate_vec
+slew_rate(g::PulsedGradient) = isnothing(g.scaling) ? maximum(slew_rate_vec(g)) : res.scaling
 δ(g::PulsedGradient) = rise_time(g) + flat_time(g)
 duration(g::PulsedGradient) = 2 * rise_time(g) + flat_time(g)
-qval(g::PulsedGradient) = (g.orientation == :neg_bvec ? -1 : 1) * gradient_strength(g) * δ(g)
-
-
-function bval(g::PulsedGradient, qstart=0.)
-    tr = rise_time(g)
-    td = δ(g)
-    grad = gradient_strength(g)
-    return (
-        # b-value due to just the gradient
-        grad * (1//60 * tr^3 - 1//12 * tr^2 * td + 1//2 * tr * td^2 + 1//3 * td^3) + 
-        # b-value due to cross-term
-        qstart * grad * (td * (td + tr)) +
-        # b-value due to just `qstart`
-        (td + tr) * qstart^2
-    )
-end
 
-variables(::Type{<:PulsedGradient}) = [qval, δ, gradient_strength, duration, rise_time, flat_time, slew_rate]
+get_children_indices(::PulsedGradient) = (:rise, :flat, :fall)
+Base.get_index(pg::PulsedGradient, symbol::Symbol) = pg[Val(symbol)]
+Base.get_index(pg::PulsedGradient, ::Val{:rise}) = pg.rise
+Base.get_index(pg::PulsedGradient, ::Val{:flat}) = pg.flat
+Base.get_index(pg::PulsedGradient, ::Val{:fall}) = pg.fall
+
+
+variables(::Type{<:PulsedGradient}) = [qval, δ, gradient_strength_vec, duration, rise_time, flat_time]
 
 
 function fixed(block::PulsedGradient)
-    if block.orientation == :bvec
-        rotate = true
-        qvec = [value(qval(block)), 0., 0.]
-    elseif block.orientation == :neg_bvec
-        rotate = true
-        qvec = [-value(qval(block)), 0., 0.]
-    elseif block.orientation isa AbstractVector && size(block.orientation) == (3, )
-        rotate = false
-        qvec = block.orientation .* (value(qval(block)) / norm(block.orientation))
-    else
-        error("Gradient orientation should be :bvec, :neg_bvec or a length-3 vector, not $(block.orienation)")
-    end
+    grad_vec = value.(gradient_strength_vec(block))
     t_rise = value(rise_time(block))
     t_d = value(δ(block))
-    return FixedBlock(
+    return FixedGradient(
         [0., t_rise, t_d, td + t_rise],
-        [zeros(3), qvec, qvec, zeros(3)]; 
+        [zeros(3), grad_vec, grad_vec, zeros(3)]; 
         rotate=rotate
     )
 end
-- 
GitLab