From 2e71567a3a960b3df8c6b9e9f427c1a155711004 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Mon, 12 Feb 2024 16:47:07 +0000
Subject: [PATCH] Better constructors

---
 .../changing_gradient_blocks.jl               | 21 ++++++--------
 .../constant_gradient_blocks.jl               | 17 ++++-------
 .../gradient_waveforms/gradient_waveforms.jl  |  6 ++--
 .../gradient_waveforms/no_gradient_blocks.jl  | 28 ++++++++++---------
 4 files changed, 32 insertions(+), 40 deletions(-)

diff --git a/src/components/gradient_waveforms/changing_gradient_blocks.jl b/src/components/gradient_waveforms/changing_gradient_blocks.jl
index 6085549..14abf70 100644
--- a/src/components/gradient_waveforms/changing_gradient_blocks.jl
+++ b/src/components/gradient_waveforms/changing_gradient_blocks.jl
@@ -4,28 +4,23 @@ import ....Variables: VariableType, duration, qvec, bmat_gradient, gradient_stre
 import ...AbstractTypes: GradientWaveform
 
 
-abstract type ChangingGradient{N} <: GradientWaveform{N} end
-
 """
-    ChangingGradient1D(grad1, slew_rate, duration)
+    ChangingGradient(grad1, slew_rate, duration)
 
-Underlying type for any linearly changing part in a 1D gradient waveform.
+Underlying type for any linearly changing part in a 1D or 3D gradient waveform (depending on whether `grad1` and `slew_rate` are a scalar or a vector)
 
 Usually, you do not want to create this object directly, use a `BuildingBlock` instead.
 """
+abstract type ChangingGradient{N} <: GradientWaveform{N} end
+(::Type{ChangingGradient})(grad1::VariableType, slew_rate::VariableType, duration::VariableType) = ChangingGradient1D(grad1, slew_rate, duration)
+(::Type{ChangingGradient})(grad1::AbstractVector, slew_rate::AbstractVector, duration::VariableType) = ChangingGradient3D(grad1, slew_rate, duration)
+
 struct ChangingGradient1D <: ChangingGradient{1}
     gradient_strength_start :: VariableType
     slew_rate :: VariableType
     duration :: VariableType
 end
 
-"""
-    ChangingGradient3D(grad1, slew_rate, duration)
-
-Underlying type for any linearly changing part in a 3D gradient waveform.
-
-Usually, you do not want to create this object directly, use a `BuildingBlock` instead.
-"""
 struct ChangingGradient3D <: ChangingGradient{3}
     gradient_strength_start :: SVector{3, <:VariableType}
     slew_rate :: SVector{3, <:VariableType}
@@ -44,7 +39,7 @@ qvec(cgb::ChangingGradient) = (grad_start(cgb) .+ grad_end(cgb)) .* (duration(cg
 _mult(g1::VariableType, g2::VariableType) = g1 * g2
 _mult(g1::AbstractVector, g2::AbstractVector) = g1 .* permutedims(g2)
 
-function bmat_gradient(cgb::ChangingGradientBlock, qstart)
+function bmat_gradient(cgb::ChangingGradient, qstart)
     # grad = (g1 * (duration - t) + g2 * t) / duration
     #      = g1 + (g2 - g1) * t / duration
     # q = qstart + g1 * t + (g2 - g1) * t^2 / (2 * duration)
@@ -62,7 +57,7 @@ function bmat_gradient(cgb::ChangingGradientBlock, qstart)
     )
 end
 
-function bmat_gradient(cgb::ChangingGradientBlock)
+function bmat_gradient(cgb::ChangingGradient)
     diff = slew_rate(cgb) .* duration(cgb)
     return (2Ï€)^2 .* (
         _mult(grad_start(cgb), grad_start(cgb)) ./ 3 .+
diff --git a/src/components/gradient_waveforms/constant_gradient_blocks.jl b/src/components/gradient_waveforms/constant_gradient_blocks.jl
index 5efffb4..65ac57b 100644
--- a/src/components/gradient_waveforms/constant_gradient_blocks.jl
+++ b/src/components/gradient_waveforms/constant_gradient_blocks.jl
@@ -3,27 +3,22 @@ import StaticArrays: SVector
 import ....Variables: VariableType, duration, qvec, bmat_gradient, gradient_strength, slew_rate, get_free_variable
 import ...AbstractTypes: GradientWaveform
 
-abstract type ConstantGradient{N} <: GradientWaveform{N} end
-
 """
-    ConstantGradient1D(gradient_strength, duration)
+    ConstantGradient(gradient_strength, duration)
 
-Underlying type for any flat part in a 1D gradient waveform.
+Underlying type for any flat part in a 1D or 3D gradient waveform (depending on whether `gradient_strength` is a scalar or a vector).
 
 Usually, you do not want to create this object directly, use a `BuildingBlock` instead.
 """
+abstract type ConstantGradient{N} <: GradientWaveform{N} end
+(::Type{ConstantGradient})(grad1::VariableType, duration::VariableType) = ConstantGradient1D(grad1, duration)
+(::Type{ConstantGradient})(grad1::AbstractVector, duration::VariableType) = ConstantGradient3D(grad1, duration)
+
 struct ConstantGradient1D <: ConstantGradient{1}
     gradient_strength :: VariableType
     duration :: VariableType
 end
 
-"""
-    ConstantGradient1D(gradient_strength, duration)
-
-Underlying type for any flat part in a 3D gradient waveform.
-
-Usually, you do not want to create this object directly, use a `BuildingBlock` instead.
-"""
 struct ConstantGradient3D <: ConstanGradient{3}
     gradient_strength :: SVector{3, <:VariableType}
     duration :: VariableType
diff --git a/src/components/gradient_waveforms/gradient_waveforms.jl b/src/components/gradient_waveforms/gradient_waveforms.jl
index c3b432d..7e632e9 100644
--- a/src/components/gradient_waveforms/gradient_waveforms.jl
+++ b/src/components/gradient_waveforms/gradient_waveforms.jl
@@ -22,8 +22,8 @@ include("no_gradient_blocks.jl")
 
 
 import ..AbstractTypes: GradientWaveform
-import .NoGradientBlocks: NoGradientBlock
-import .ChangingGradientBlocks: ChangingGradientBlock, split_gradient
-import .ConstantGradientBlocks: ConstantGradientBlock
+import .NoGradientBlocks: NoGradient
+import .ChangingGradientBlocks: ChangingGradient, split_gradient
+import .ConstantGradientBlocks: ConstantGradient
 
 end
\ No newline at end of file
diff --git a/src/components/gradient_waveforms/no_gradient_blocks.jl b/src/components/gradient_waveforms/no_gradient_blocks.jl
index 3552f22..d11996b 100644
--- a/src/components/gradient_waveforms/no_gradient_blocks.jl
+++ b/src/components/gradient_waveforms/no_gradient_blocks.jl
@@ -5,15 +5,17 @@ import ...AbstractTypes: GradientWaveform
 import ..ChangingGradientBlocks: split_gradient
 
 """
-    NoGradientBlock(duration)
+    NoGradient{N}(duration)
 
 Part of a gradient waveform when there is no gradient active.
 
+`N` needs to be set to 1 if `orientation` is fixed in the gradient waveform or 3 otherwise.
+
 Usually, you do not want to create this object directly, use a `BuildingBlock` instead.
 """
-struct NoGradientBlock{N} <: GradientWaveform{N}
+struct NoGradient{N} <: GradientWaveform{N}
     duration :: VariableType
-    function NoGradientBlock{N}(duration)
+    function NoGradient{N}(duration)
         if !(N in (1, 3))
             error("Dimensionality of the gradient should be 1 or 3, not $N")
         end
@@ -21,23 +23,23 @@ struct NoGradientBlock{N} <: GradientWaveform{N}
     end
 end
 
-duration(ngb::NoGradientBlock) = duration(ngb)
+duration(ngb::NoGradient) = duration(ngb)
 for func in (:qvec, :gradient_strength, :slew_rate)
-    @eval $func(::NoGradientBlock{1}) = 0.
-    @eval $func(::NoGradientBlock{3}) = zero(SVector{3, Float64})
+    @eval $func(::NoGradient{1}) = 0.
+    @eval $func(::NoGradient{3}) = zero(SVector{3, Float64})
 end
 
-bmat_gradient(::NoGradientBlock{1}) = 0.
-bmat_gradient(::NoGradientBlock{3}) = zero(SMatrix{3, 3, Float64, 9})
+bmat_gradient(::NoGradient{1}) = 0.
+bmat_gradient(::NoGradient{3}) = zero(SMatrix{3, 3, Float64, 9})
 
-bmat_gradient(::NoGradientBlock, ) = 0.
-bmat_gradient(ngb::NoGradientBlock{1}, qstart::VariableType) = qstart^2 * duration(ngb)
-bmat_gradient(ngb::NoGradientBlock{3}, qstart::AbstractVector{<:VariableType}) = @. qstart * permutedims(qstart) * duration(ngb)
+bmat_gradient(::NoGradient, ) = 0.
+bmat_gradient(ngb::NoGradient{1}, qstart::VariableType) = qstart^2 * duration(ngb)
+bmat_gradient(ngb::NoGradient{3}, qstart::AbstractVector{<:VariableType}) = @. qstart * permutedims(qstart) * duration(ngb)
 
-function split_gradient(ngb::NoGradientBlock, times::VariableType...)
+function split_gradient(ngb::NoGradient, times::VariableType...)
     durations = [times[1], [t[2] - t[1] for t in zip(times[1:end-1], times[2:end])]..., duration(ngb) - times[end]]
     @assert all(durations >= 0.)
-    return [NoGradientBlock(d) for d in durations]
+    return [NoGradient(d) for d in durations]
 end
 
 end
\ No newline at end of file
-- 
GitLab