From ffbbf63709b1605efcda6d3a86419022c00b002b Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Fri, 26 Apr 2024 12:05:51 +0100
Subject: [PATCH] Fix gradient orientation extraction

---
 src/components/abstract_types.jl    |  3 ++-
 src/components/instant_gradients.jl |  6 +++--
 src/containers/base_sequences.jl    |  8 +++++-
 src/containers/building_blocks.jl   | 26 +++++++++++++++++--
 src/variables.jl                    | 12 ++-------
 test/runtests.jl                    | 10 +++++---
 test/test_post_hoc.jl               | 40 +++++++++++++++++++++++++++++
 7 files changed, 85 insertions(+), 20 deletions(-)
 create mode 100644 test/test_post_hoc.jl

diff --git a/src/components/abstract_types.jl b/src/components/abstract_types.jl
index b366b5c..20fd27d 100644
--- a/src/components/abstract_types.jl
+++ b/src/components/abstract_types.jl
@@ -1,5 +1,5 @@
 module AbstractTypes
-import ...Variables: AbstractBlock, duration, variables, effective_time, adjustable
+import ...Variables: AbstractBlock, duration, variables, effective_time, adjustable, gradient_orientation
 
 """
 Super-type for all individual components that form an MRI sequence (i.e., RF pulse, gradient waveform, or readout event).
@@ -52,5 +52,6 @@ split_timestep(comp_tuple::Tuple{<:Number, <:EventComponent}, precision::Number)
 
 adjustable(::RFPulseComponent) = :pulse
 adjustable(::GradientWaveform) = :gradient
+gradient_orientation(gw::GradientWaveform{1}) = gw.orientation
 
 end
\ No newline at end of file
diff --git a/src/components/instant_gradients.jl b/src/components/instant_gradients.jl
index a167872..06e3a23 100644
--- a/src/components/instant_gradients.jl
+++ b/src/components/instant_gradients.jl
@@ -1,7 +1,7 @@
 module InstantGradients
 import StaticArrays: SVector, SMatrix
 import JuMP: @constraint
-import ...Variables: VariableType, duration, qval, bmat_gradient, get_free_variable, set_simple_constraints!, effective_time, make_generic, adjust_internal, adjustable
+import ...Variables: VariableType, duration, qval, bmat_gradient, get_free_variable, set_simple_constraints!, effective_time, make_generic, adjust_internal, adjustable, gradient_orientation
 import ...BuildSequences: global_model
 import ..AbstractTypes: EventComponent, GradientWaveform
 
@@ -38,7 +38,7 @@ An [`InstantGradient`](@ref) with a fixed orientation.
 """
 struct InstantGradient1D <: InstantGradient{1}
     qval :: VariableType
-    orientation :: SVector{3, Number}
+    orientation :: SVector{3, Float64}
     group :: Union{Nothing, Symbol}
 end
 
@@ -72,6 +72,8 @@ make_generic(ig::InstantGradient) = ig
 
 adjustable(::InstantGradient) = :gradient
 
+gradient_orientation(ig::InstantGradient{1}) = ig.orientation
+
 function adjust_internal(ig::InstantGradient1D; orientation=nothing, scale=1., rotation=nothing)
     if !isnothing(orientation) && !isnothing(rotation)
         error("Cannot set both the gradient orientation and rotation.")
diff --git a/src/containers/base_sequences.jl b/src/containers/base_sequences.jl
index 964f470..f5482d5 100644
--- a/src/containers/base_sequences.jl
+++ b/src/containers/base_sequences.jl
@@ -4,7 +4,7 @@ Defines [`BaseSequence`](@ref) and [`Sequence`](@ref)
 module BaseSequences
 import StaticArrays: SVector
 import JuMP: @constraint
-import ...Variables: get_free_variable, repetition_time, VariableType, duration, variables, VariableNotAvailable, Variables, set_simple_constraints!, TR, make_generic, gradient_strength, amplitude, phase, gradient_strength3, get_gradient, get_pulse, frequency
+import ...Variables: get_free_variable, repetition_time, VariableType, duration, variables, VariableNotAvailable, Variables, set_simple_constraints!, TR, make_generic, gradient_strength, amplitude, phase, gradient_strength3, get_gradient, get_pulse, frequency, gradient_orientation, get_gradient
 import ...BuildSequences: global_model, global_scanner
 import ...Components: EventComponent, NoGradient
 import ...Scanners: Scanner, B0
@@ -76,6 +76,12 @@ function start_time(bs::BaseSequence{N}, s::Symbol) where {N}
     return start_time(bs, index)
 end
 
+function gradient_orientation(seq::BaseSequence)
+    return gradient_orientation(get_gradient(seq))
+end
+
+gradient_orientation(nt::NamedTuple) = map(gradient_orientation, nt)
+
 
 """
     get_index_single_TR(sequence, index)
diff --git a/src/containers/building_blocks.jl b/src/containers/building_blocks.jl
index f50247d..0ce1465 100644
--- a/src/containers/building_blocks.jl
+++ b/src/containers/building_blocks.jl
@@ -9,7 +9,7 @@ import ..Abstract: ContainerBlock, start_time, readout_times, edge_times, end_ti
 import ...BuildSequences: global_model
 import ...Components: BaseComponent, GradientWaveform, EventComponent, NoGradient, ChangingGradient, ConstantGradient, split_gradient, RFPulseComponent, ReadoutComponent, InstantGradient
 import ...Variables: qval, bmat_gradient, effective_time, get_free_variable, qval3, slew_rate, gradient_strength, amplitude, phase, frequency
-import ...Variables: VariableType, duration, make_generic, get_pulse, get_readout, scanner_constraints!, get_gradient
+import ...Variables: VariableType, duration, make_generic, get_pulse, get_readout, scanner_constraints!, get_gradient, gradient_orientation
 
 """
 Basic BuildingBlock, which can consist of a gradient waveforms with any number of RF pulses/readouts overlaid
@@ -68,6 +68,19 @@ function ndim_grad(bb::BaseBuildingBlock)
     error("$(typeof(bb)) contains both 1D and 3D gradient waveforms.")
 end
 
+function gradient_orientation(bb::BaseBuildingBlock)
+    for (_, ws) in waveform_sequence(bb)
+        if ws isa GradientWaveform{1}
+            return gradient_orientation(ws)
+        end
+    end
+    for (_, e) in events(bb)
+        if e isa InstantGradient{1}
+            return gradient_orientation(e)
+        end
+    end
+end
+
 
 """
     waveform(building_block)
@@ -190,7 +203,16 @@ function qval(bb::BaseBuildingBlock, index1, index2)
     if (!isnothing(index1)) && (index1 == index2)
         return 0.
     end
-    sum([qval(wv) for (_, wv) in waveform_sequence(bb, index1, index2)])
+    res = sum([qval(wv) for (_, wv) in waveform_sequence(bb, index1, index2)])
+
+    t1 = isnothing(index1) ? 0. : start_time(bb, index1)
+    t2 = isnothing(index2) ? duration(bb) : start_time(bb, index2)
+    for (key, event) in events(bb)
+        if event isa InstantGradient && (t1 <= start_time(bb, key) <= t2)
+            res = res .+ qval(event)
+        end
+    end
+    return res
 end
 qval(bb::BaseBuildingBlock) = qval(bb, nothing, nothing)
 
diff --git a/src/variables.jl b/src/variables.jl
index 8754f10..949de4f 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -227,15 +227,7 @@ function bmat_gradient end
 
 Returns the gradient orientation.
 """
-function gradient_orientation(bb::AbstractBlock)
-    if hasproperty(bb, :orientation)
-        return bb.orientation
-    else
-        return gradient_orientation(get_gradient(bb))
-    end
-end
-
-gradient_orientation(nt::NamedTuple) = map(gradient_orientation, nt)
+function gradient_orientation end
 
 
 function effective_time end
@@ -324,7 +316,7 @@ for base_fn in [:qval, :gradient_strength, :slew_rate]
             elseif value isa AbstractVector
                 return value
             else
-                return value .* bb.orientation
+                return value .* gradient_orientation(bb)
             end
         end
     end
diff --git a/test/runtests.jl b/test/runtests.jl
index bba1f16..ef2f03e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,8 +2,10 @@ using MRIBuilder
 using Test
 
 @testset "MRIBuilder.jl" begin
-    include("test_components.jl")
-    include("test_sequences.jl")
-    include("test_IO.jl")
-    include("test_plot.jl")
+    #include("test_building_blocks.jl")
+    include("test_post_hoc.jl")
+    #include("test_components.jl")
+    #include("test_sequences.jl")
+    #include("test_IO.jl")
+    #include("test_plot.jl")
 end
diff --git a/test/test_post_hoc.jl b/test/test_post_hoc.jl
new file mode 100644
index 0000000..83ced17
--- /dev/null
+++ b/test/test_post_hoc.jl
@@ -0,0 +1,40 @@
+@testset "test_post_hoc.jl" begin
+    import Rotations: RotationVec
+
+    @testset "adjust different components" begin
+        @testset "finite gradients" begin
+            dwi = DWI(bval=1., TE=:min)
+            @test bval(dwi) ≈ 1.
+            qval_orig = qval(dwi[:gradient])
+            @test all(qval3(dwi[:gradient]) .≈ [qval_orig, 0., 0.])
+
+            @testset "scale and change orientation" begin
+                new_dwi = adjust(dwi, DWI=(scale=0.5, orientation=[0., 1., 0.]))
+                @test bval(new_dwi) ≈ 0.25
+                @test all(qval3(new_dwi[:gradient]) .≈ [0., qval_orig/2, 0.])
+            end
+            @testset "Rotate gradient" begin
+                new_dwi = adjust(dwi, gradient=(rotation=RotationVec(0., 0., π/4), ))
+                @test bval(new_dwi) ≈ 1.
+                @test all(qval3(new_dwi[:gradient]) .≈ [qval_orig/√2, qval_orig/√2, 0.])
+            end
+        end
+        @testset "instant gradients" begin
+            dwi = DWI(bval=1., TE=80, Δ=40, gradient=(type=:instant, ))
+            @test bval(dwi) ≈ 1.
+            qval_orig = qval(dwi[:gradient])
+            @test all(qval3(dwi[:gradient]) .≈ [qval_orig, 0., 0.])
+
+            @testset "scale and change orientation" begin
+                new_dwi = adjust(dwi, DWI=(scale=0.5, orientation=[0., 1., 0.]))
+                @test bval(new_dwi) ≈ 0.25
+                @test all(qval3(new_dwi[:gradient]) .≈ [0., qval_orig/2, 0.])
+            end
+            @testset "Rotate gradient" begin
+                new_dwi = adjust(dwi, gradient=(rotation=RotationVec(0., 0., π/4), ))
+                @test bval(new_dwi) ≈ 1.
+                @test all(qval3(new_dwi[:gradient]) .≈ [qval_orig/√2, qval_orig/√2, 0.])
+            end
+        end
+    end
+end
-- 
GitLab