From e3922b61f40396a61e16a6dc0a30286c4a6ec045 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Thu, 15 Feb 2024 18:50:17 +0000
Subject: [PATCH] Ensure vectors are 3D when computing bmat_gradient

---
 .../changing_gradient_blocks.jl                   | 15 ++++++++++-----
 .../constant_gradient_blocks.jl                   |  9 ++++++---
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/src/components/gradient_waveforms/changing_gradient_blocks.jl b/src/components/gradient_waveforms/changing_gradient_blocks.jl
index 7c7a2a8..d76c5e5 100644
--- a/src/components/gradient_waveforms/changing_gradient_blocks.jl
+++ b/src/components/gradient_waveforms/changing_gradient_blocks.jl
@@ -44,7 +44,10 @@ qval(cgb::ChangingGradient) = (grad_start(cgb) .+ grad_end(cgb)) .* (duration(cg
 _mult(g1::VariableType, g2::VariableType) = g1 * g2
 _mult(g1::AbstractVector, g2::AbstractVector) = g1 .* permutedims(g2)
 
-function bmat_gradient(cgb::ChangingGradient, qstart)
+to_vec(cgb::ChangingGradient1D, g::VariableType) = cgb.orientation .* g
+to_vec(::ChangingGradient3D, g::AbstractVector) = g
+
+function bmat_gradient(cgb::ChangingGradient, qstart::AbstractVector)
     # grad = (g1 * (duration - t) + g2 * t) / duration
     #      = g1 + (g2 - g1) * t / duration
     # q = qstart + g1 * t + (g2 - g1) * t^2 / (2 * duration)
@@ -55,18 +58,20 @@ function bmat_gradient(cgb::ChangingGradient, qstart)
     #   g1^2 * duration^3 / 3 +
     #   g1 * (g2 - g1) * duration^3 / 4 +
     #   (g2 - g1)^2 * duration^3 / 10
+    grad_aver = to_vec(cgb, 2 .* grad_start(cgb) .+ grad_end(cgb))
     return (
         _mult(qstart, qstart) .* duration(cgb) .+
-        duration(cgb)^2 .* _mult(qstart, 2 .* grad_start(cgb) .+ grad_end(cgb)) .* 2Ï€ ./ 3 .+
+        duration(cgb)^2 .* _mult(qstart, grad_aver) .* 2Ï€ ./ 3 .+
         bmat_gradient(cgb)
     )
 end
 
 function bmat_gradient(cgb::ChangingGradient)
-    diff = slew_rate(cgb) .* duration(cgb)
+    gs = to_vec(cgb, grad_start(cgb))
+    diff = to_vec(cgb, slew_rate(cgb) .* duration(cgb))
     return (2Ï€)^2 .* (
-        _mult(grad_start(cgb), grad_start(cgb)) ./ 3 .+
-        _mult(grad_start(cgb), diff) ./ 4 .+
+        _mult(gs, gs) ./ 3 .+
+        _mult(gs, diff) ./ 4 .+
         _mult(diff, diff) ./ 10
     ) .* duration(cgb)^3
 end
diff --git a/src/components/gradient_waveforms/constant_gradient_blocks.jl b/src/components/gradient_waveforms/constant_gradient_blocks.jl
index a5fb178..508cbbf 100644
--- a/src/components/gradient_waveforms/constant_gradient_blocks.jl
+++ b/src/components/gradient_waveforms/constant_gradient_blocks.jl
@@ -40,17 +40,20 @@ qval(cgb::ConstantGradient3D) = @. duration(cgb) * gradient_strength(cgb) * 2Ï€
 _mult(g1::VariableType, g2::VariableType) = g1 * g2
 _mult(g1::AbstractVector, g2::AbstractVector) = g1 .* permutedims(g2)
 
+to_vec(cgb::ConstantGradient1D, g::VariableType) = cgb.orientation .* g
+to_vec(::ConstantGradient3D, g::AbstractVector) = g
+
 function bmat_gradient(cgb::ConstantGradient)
-    grad = 2Ï€ .* gradient_strength(cgb)
+    grad = to_vec(cgb, 2Ï€ .* gradient_strength(cgb))
     return _mult(grad, grad) .* duration(cgb)^3 ./3
 end
 
-function bmat_gradient(cgb::ConstantGradient, qstart)
+function bmat_gradient(cgb::ConstantGradient, qstart::AbstractVector)
     # \int dt (qstart + t * grad)^2 = 
     #   qstart^2 * duration +
     #   qstart * grad * duration^2 +
     #   grad * grad * duration^3 / 3 +
-    grad = 2Ï€ .* gradient_strength(cgb)
+    grad = to_vec(cgb, 2Ï€ .* gradient_strength(cgb))
     return (
         _mult(qstart, qstart) .* duration(cgb) .+
         _mult(qstart, grad) .* duration(cgb)^2 .+
-- 
GitLab