From 5a376c1669b1007858c806722da86c5233fd7633 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Sun, 28 Jan 2024 21:06:22 +0000
Subject: [PATCH] Update instant gradients

---
 src/gradients/instant_gradients.jl | 58 +++++++++++++++++++++++-------
 1 file changed, 46 insertions(+), 12 deletions(-)

diff --git a/src/gradients/instant_gradients.jl b/src/gradients/instant_gradients.jl
index 435a633..58c96bb 100644
--- a/src/gradients/instant_gradients.jl
+++ b/src/gradients/instant_gradients.jl
@@ -1,42 +1,76 @@
 module InstantGradients
-import JuMP: @constraint, @variable, Model, owner_model
-import ...Variables: qval, bval, start_time, duration, variables, get_free_variable, VariableType
+import JuMP: @constraint, @variable, Model, owner_model, AbstractJuMPScalar
+import ...Variables: qvec, bmat, duration, variables, get_free_variable, VariableType
 import ...BuildingBlocks: GradientBlock, fixed
 import ...BuildSequences: @global_model_constructor
 import ..FixedGradients: FixedInstantGradient
 
 """
-    InstantGradientBlock(; orientation=:bvec, qval=nothing)
+    InstantGradientBlock(; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing)
 
 Defines an instantaneous gradient.
 
 ## Parameters
-- `orientation` sets the gradient orienation. Can be set to a vector for a fixed orientation. Alternatively, can be set to :bvec (default) to rotate with the user-provided `bvecs` or to :neg_bvec to always be the reverse of the `bvecs`.
+- `orientation` sets the gradient orienation (ignored if `qvec` is set). Can be set to a vector for a fixed orientation. Otherwise the orientation will be aligned with the `rotate` (if set) or fully free (if `rotate` is nothing).
+- `rotate`: with which user-set parameter will this gradient be rotated (e.g., :bvec). Default is no rotation.
+- `scale`: with which user-set parameter will this gradient be scaled (e.g., :bval). Default is no scaling.
 
 ## Variables
+- [`qvec`](@ref): Spatial scale and direction on which spins will be dephased due to this pulsed gradient in rad/um.
 - [`qval`](@ref): Spatial scale on which spins will be dephased due to this pulsed gradient in rad/um.
 """
 struct InstantGradientBlock <: GradientBlock
     model::Model
-    orientation :: Any
-    qval :: VariableType
+    qvec :: SVector{3, VariableType}
+    rotate :: Union{Nothing, Symbol}
+    scale :: Union{Nothing, Symbol}
 end
 
 @global_model_constructor InstantGradientBlock
 
-function InstantGradientBlock(model::Model; orientation=:bvec, qval=nothing)
+function InstantGradientBlock(model::Model; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing)
+    used_qval = false
+    if isnothing(qvec)
+        if isnothing(orientation)
+            if isnothing(rotate)
+                qvec = SVector{3}(
+                    get_free_variable(nothing),
+                    get_free_variable(nothing),
+                    get_free_variable(nothing),
+                )
+            else
+                qvec = SVector{3}(
+                    get_free_variable(qval),
+                    0.,
+                    0.,
+                )
+                used_qval = true
+            end
+        else
+            qval = get_free_variable(qval)
+            qvec = (orientation ./ norm(orientation)) * qval
+            used_qval = true
+        end
+    end
+
     res = InstantGradientBlock(
         model,
-        orientation,
-        get_free_variable(model, qval),
+        qvec,
+        rotate,
+        scale
     )
-    @constraint model model.qval >= 0
+    if !used_qval && !isnothing(qval)
+        @constraint model qval_sqr(res) == qval^2
+    end
+    if qval isa AbstractJuMPScalar
+        @constraint model qval >= 0
+    end
     return res
 end
 
 
-qval(instant::InstantGradientBlock) = instant.qval
-bval(instant::InstantGradientBlock) = 0.
+qvec(instant::InstantGradientBlock) = instant.qvec
+bmat(instant::InstantGradientBlock, qstart=nothing) = zeros(3, 3)
 duration(instant::InstantGradientBlock) = 0.
 variables(::Type{<:InstantGradientBlock}) = [qval]
 
-- 
GitLab