From ab9ea84c3de22a2a5cb3f209d7af1c7c4f9e7eb1 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Thu, 15 Feb 2024 18:38:34 +0000
Subject: [PATCH] Pass on orientation and group to GradientWaveform components

---
 src/all_building_blocks/building_blocks.jl      | 13 +++++++------
 src/all_building_blocks/spoilt_slice_selects.jl | 17 +++++++++--------
 src/all_building_blocks/trapezoids.jl           | 13 ++++++++-----
 src/components/instant_gradients.jl             |  4 ++--
 4 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/src/all_building_blocks/building_blocks.jl b/src/all_building_blocks/building_blocks.jl
index 79c4b94..92ae8e2 100644
--- a/src/all_building_blocks/building_blocks.jl
+++ b/src/all_building_blocks/building_blocks.jl
@@ -5,7 +5,7 @@ import ...Variables: VariableType, duration, make_generic, get_pulse, get_readou
 import ...Components: BaseComponent, DelayedEvent, RFPulseComponent, ReadoutComponent
 
 """
-    BuildingBlock(waveform, events; duration=nothing, orientation=nothing)
+    BuildingBlock(waveform, events; duration=nothing, orientation=nothing, group)
 
 Generic [`BaseBuildingBlock`](@ref) that can capture any overlapping gradients, RF pulses, and/or readouts.
 The gradients cannot contain any free variables.
@@ -15,6 +15,7 @@ The gradients cannot contain any free variables.
 - `events`: Sequence of 2-element tuples with (index, pulse/readout). The start time of the pulse/readout at the start of the gradient waveform element with index `index` (use [`DelayedEvent`](@ref) to make this earlier or later).
 - `duration`: duration of this `BuildingBlock`. If not set then it will be assumed to be the time of the last element in `waveform`.
 - `orientation`: orientation of the gradients in the waveform. If not set, then the full gradient vector should be given explicitly.
+- `group`: group of the gradient waveform
 """
 struct BuildingBlock <: BaseBuildingBlock
     parts :: Vector{<:BaseComponent}
@@ -27,10 +28,10 @@ struct BuildingBlock <: BaseBuildingBlock
     end
 end
 
-function BuildingBlock(waveform::AbstractVector, events::AbstractVector; duration=nothing, orientation=nothing)
+function BuildingBlock(waveform::AbstractVector, events::AbstractVector; duration=nothing, orientation=nothing, group=nothing)
     events = Any[events...]
     waveform = Any[waveform...]
-    ndim = isnothing(orientat) ? 1 : 3
+    ndim = isnothing(orientation) ? 1 : 3
     zero_grad = isnothing(orientation) ? zeros(3) : 0.
     if length(waveform) == 0 || waveform[1][1] > 0.
         pushfirst!(waveform, (0., zero_grad))
@@ -48,11 +49,11 @@ function BuildingBlock(waveform::AbstractVector, events::AbstractVector; duratio
     for (index_grad, ((prev_time, prev_grad), (time, grad))) in enumerate(zip(waveform[1:end-1], waveform[2:end]))
         duration = time - prev_time
         if norm(prev_grad) <= 1e-12 && norm(grad) <= 1e-12
-            push!(components, NoGradientBlock{ndim}(duration))
+            push!(components, NoGradient{ndim}(duration))
         elseif norm(prev_grad) ≈ norm(grad)
-            push!(components, ConstantGradientBlock(prev_grad, duration))
+            push!(components, ConstantGradient(prev_grad, orientation, duration, group))
         else
-            push!(components, ChangingGradientBlock(prev_grad, (grad .- prev_grad) ./ duration, duration))
+            push!(components, ChangingGradient(prev_grad, (grad .- prev_grad) ./ duration, orientation, duration, group))
         end
         while length(events) > 0 && index_grad == events[1][1]
             (_, event) = popfirst!(events)
diff --git a/src/all_building_blocks/spoilt_slice_selects.jl b/src/all_building_blocks/spoilt_slice_selects.jl
index 6cbccea..272cfca 100644
--- a/src/all_building_blocks/spoilt_slice_selects.jl
+++ b/src/all_building_blocks/spoilt_slice_selects.jl
@@ -4,7 +4,7 @@ import LinearAlgebra: norm
 import StaticArrays: SVector
 import JuMP: @constraint, @objective, objective_function
 import ...BuildSequences: global_model, global_scanner
-import ...Variables: VariableType, duration, rise_time, flat_time, effective_time, qval, gradient_strength, slew_rate, inverse_slice_thickness, get_free_variable, get_pulse, set_simple_constraints!
+import ...Variables: VariableType, duration, rise_time, flat_time, effective_time, qval, gradient_strength, slew_rate, inverse_slice_thickness, get_free_variable, get_pulse, set_simple_constraints!, gradient_orientation
 import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent
 import ..BaseBuildingBlocks: BaseBuildingBlock
 
@@ -92,18 +92,19 @@ function SpoiltSliceSelect(pulse::RFPulseComponent; orientation=[0, 0, 1], group
     return res
 end
 
+gradient_orientation(spoilt::SpoiltSliceSelect) = spoilt.orientation
 duration_trap1(spoilt::SpoiltSliceSelect) = 2 * spoilt.rise_time1 + spoilt.flat_time1 - spoilt.diff_time
 duration_trap2(spoilt::SpoiltSliceSelect) = 2 * spoilt.fall_time2 + spoilt.flat_time2 - spoilt.diff_time
 
 Base.keys(::SpoiltSliceSelect) = Val.((:rise1, :flat1, :fall1, :flat_pulse, :pulse, :rise2, :flat2, :fall2))
-Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:rise1}) = ChangingGradient(0., slew_rate(spoilt), rise_time(spoilt)[1])
-Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:flat1}) = ConstantGradient(slew_rate(spoilt) * rise_time(spoilt)[1], flat_time(spoilt)[1])
-Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:fall1}) = ChangingGradient(slew_rate(spoilt) * rise_time(spoilt)[1], -slew_rate(spoilt), fall_time(spoilt)[1])
-Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:flat_pulse}) = ConstantGradient(slew_rate(spoilt) * spoilt.diff_time, duration(spoilt.pulse))
+Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:rise1}) = ChangingGradient(0., slew_rate(spoilt), gradient_orientation(spoilt), rise_time(spoilt)[1], spoilt.group)
+Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:flat1}) = ConstantGradient(slew_rate(spoilt) * rise_time(spoilt)[1], gradient_orientation(spoilt), flat_time(spoilt)[1], spoilt.group)
+Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:fall1}) = ChangingGradient(slew_rate(spoilt) * rise_time(spoilt)[1], -slew_rate(spoilt), gradient_orientation(spoilt), fall_time(spoilt)[1], spoilt.group)
+Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:flat_pulse}) = ConstantGradient(slew_rate(spoilt) * spoilt.diff_time, gradient_orientation(spoilt), duration(spoilt.pulse), spoilt.group)
 Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:pulse}) = spoilt.pulse
-Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:rise2}) = ChangingGradient(slew_rate(spoilt) * spoilt.diff_time, slew_rate(spoilt), rise_time(spoilt)[2])
-Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:flat2}) = ConstantGradient(slew_rate(spoilt) * fall_time(spoilt)[2], flat_time(spoilt)[2])
-Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:fall2}) = ChangingGradient(slew_rate(spoilt) * fall_time(spoilt)[2], -slew_rate(spoilt), fall_time(spoilt)[2])
+Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:rise2}) = ChangingGradient(slew_rate(spoilt) * spoilt.diff_time, slew_rate(spoilt), gradient_orientation(spoilt), rise_time(spoilt)[2], spoilt.group)
+Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:flat2}) = ConstantGradient(slew_rate(spoilt) * fall_time(spoilt)[2], gradient_orientation(spoilt), flat_time(spoilt)[2], spoilt.group)
+Base.getindex(spoilt::SpoiltSliceSelect, ::Val{:fall2}) = ChangingGradient(slew_rate(spoilt) * fall_time(spoilt)[2], -slew_rate(spoilt), gradient_orientation(spoilt), fall_time(spoilt)[2], spoilt.group)
 
 rise_time(spoilt::SpoiltSliceSelect) = (spoilt.rise_time1, spoilt.fall_time2 - spoilt.diff_time)
 flat_time(spoilt::SpoiltSliceSelect) = (spoilt.flat_time1, spoilt.flat_time2)
diff --git a/src/all_building_blocks/trapezoids.jl b/src/all_building_blocks/trapezoids.jl
index 6c69bf3..e66edbc 100644
--- a/src/all_building_blocks/trapezoids.jl
+++ b/src/all_building_blocks/trapezoids.jl
@@ -7,7 +7,7 @@ import JuMP: @constraint
 import StaticArrays: SVector
 import LinearAlgebra: norm
 import ...Variables: qval, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_bandwidth, effective_time, qval_square, duration, set_simple_constraints!, scanner_constraints!, inverse_slice_thickness
-import ...Variables: Variables, all_variables_symbols, dwell_time, inverse_fov, inverse_voxel_size, fov, voxel_size, get_gradient, get_pulse, get_readout
+import ...Variables: Variables, all_variables_symbols, dwell_time, inverse_fov, inverse_voxel_size, fov, voxel_size, get_gradient, get_pulse, get_readout, gradient_orientation
 import ...BuildSequences: global_model
 import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent, ADC
 import ..BaseBuildingBlocks: BaseBuildingBlock
@@ -31,7 +31,7 @@ abstract type BaseTrapezoid{N} <: BaseBuildingBlock end
 Defines a trapezoidal pulsed gradient
 
 ## Parameters
-- `orientation` sets the gradient orienation (completely free by default). Can be set to a vector for a fixed orientation.
+- `orientation` sets the gradient orientation (completely free by default). Can be set to a vector for a fixed orientation.
 - `group`: assign the trapezoidal gradient to a specific group. This group will be used to scale or rotate the gradients after optimisation.
 
 ## Variables
@@ -97,9 +97,12 @@ end
 
 Base.keys(::Trapezoid) = (Val(:rise), Val(:flat), Val(:fall))
 
-Base.getindex(pg::BaseTrapezoid{N}, ::Val{:rise}) where {N} = ChangingGradient(N == 3 ? zeros(3) : 0., slew_rate(pg), rise_time(pg))
-Base.getindex(pg::BaseTrapezoid, ::Val{:flat}) = ConstantGradient(gradient_strength(pg), flat_time(pg))
-Base.getindex(pg::BaseTrapezoid, ::Val{:fall}) = ChangingGradient(gradient_strength(pg), -slew_rate(pg), rise_time(pg))
+Base.getindex(pg::BaseTrapezoid{N}, ::Val{:rise}) where {N} = ChangingGradient(N == 3 ? zeros(3) : 0., slew_rate(pg), gradient_orientation(pg), rise_time(pg), pg.group)
+Base.getindex(pg::BaseTrapezoid, ::Val{:flat}) = ConstantGradient(gradient_strength(pg), gradient_orientation(pg), flat_time(pg), pg.group)
+Base.getindex(pg::BaseTrapezoid, ::Val{:fall}) = ChangingGradient(gradient_strength(pg), -slew_rate(pg), gradient_orientation(pg), rise_time(pg), pg.group)
+gradient_orientation(::BaseTrapezoid{3}) = nothing
+gradient_orientation(pg::BaseTrapezoid{1}) = gradient_orientation(get_gradient(pg))
+gradient_orientation(pg::Trapezoid{1}) = pg.orientation
 
 rise_time(pg::Trapezoid) = pg.rise_time
 flat_time(pg::Trapezoid) = pg.flat_time
diff --git a/src/components/instant_gradients.jl b/src/components/instant_gradients.jl
index 6c0da62..dadd65e 100644
--- a/src/components/instant_gradients.jl
+++ b/src/components/instant_gradients.jl
@@ -9,7 +9,7 @@ import ..AbstractTypes: EventComponent, GradientWaveform
 If the `orientation` is set an [`InstantGradient1D`](@ref) is returned, otherwise an [`InstantGradient3D`](@ref).
 
 ## Parameters
-- `orientation` sets the gradient orienation as a length-3 vector. If not set, the gradient can be in any direction.
+- `orientation` sets the gradient orientation as a length-3 vector. If not set, the gradient can be in any direction.
 - `group`: name of the group to which this gradient belongs (used for scaling and rotating).
 
 ## Variables
@@ -19,7 +19,7 @@ If the `orientation` is set an [`InstantGradient1D`](@ref) is returned, otherwis
 abstract type InstantGradient <: EventComponent end
 
 function (::Type{InstantGradient})(; orientation=nothing, group=nothing, qval=nothing, variables...)
-    if isnothing(orientaiton)
+    if isnothing(orientation)
         res = InstantGradient3D(get_free_variable.(qval), group)
     else
         res = InstantGradient1D(get_free_variable(qval), orientation, group)
-- 
GitLab