From 6ec81d187366bf7e0b5ab95ab2027cb33d93bc58 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Mon, 25 Mar 2024 16:50:24 +0000
Subject: [PATCH] Add instant gradients to pathway

---
 src/components/instant_gradients.jl |  8 ++++----
 src/pathways.jl                     | 19 ++++++++++++++++++-
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/src/components/instant_gradients.jl b/src/components/instant_gradients.jl
index dadd65e..8785641 100644
--- a/src/components/instant_gradients.jl
+++ b/src/components/instant_gradients.jl
@@ -4,7 +4,7 @@ import ...Variables: VariableType, duration, qval, bmat_gradient, get_free_varia
 import ..AbstractTypes: EventComponent, GradientWaveform
 
 """
-    InstantGradient1D(; orientation=nothing, group=nothing, variables...)
+    InstantGradient(; orientation=nothing, group=nothing, variables...)
 
 If the `orientation` is set an [`InstantGradient1D`](@ref) is returned, otherwise an [`InstantGradient3D`](@ref).
 
@@ -16,7 +16,7 @@ If the `orientation` is set an [`InstantGradient1D`](@ref) is returned, otherwis
 - [`qval`](@ref): Spatial frequency on which spins will be dephased due to this pulsed gradient in rad/um (scalar if `orientation` is set and vector otherwise).
 - [`spoiler_scale`](@ref): Length-scale on which spins will be dephased by exactly 2Ï€ in mm.
 """
-abstract type InstantGradient <: EventComponent end
+abstract type InstantGradient{N} <: EventComponent end
 
 function (::Type{InstantGradient})(; orientation=nothing, group=nothing, qval=nothing, variables...)
     if isnothing(orientation)
@@ -31,7 +31,7 @@ end
 """
 An [`InstantGradient`](@ref) with a fixed orientation.
 """
-struct InstantGradient1D <: InstantGradient
+struct InstantGradient1D <: InstantGradient{1}
     qval :: VariableType
     orientation :: SVector{3, Number}
     group :: Union{Nothing, Symbol}
@@ -42,7 +42,7 @@ qval(ig::InstantGradient1D) = ig.qval
 """
 An [`InstantGradient`](@ref) with a variable orientation.
 """
-struct InstantGradient3D <: InstantGradient
+struct InstantGradient3D <: InstantGradient{3}
     qvec :: SVector{3, VariableType}
     group :: Union{Nothing, Symbol}
 end
diff --git a/src/pathways.jl b/src/pathways.jl
index ed08660..359507a 100644
--- a/src/pathways.jl
+++ b/src/pathways.jl
@@ -285,7 +285,6 @@ The function should return `true` if the `Pathway` has reached its end (i.e., th
 function walk_pathway!(seq::Sequence, walker::PathwayWalker, pulse_effects::Vector{Symbol}, nreadout::Ref{Int}) 
     current_TR = 0
     nwait = length(pulse_effects) + nreadout[]
-    println("processing")
     while !(walk_pathway!(seq, walker, pulse_effects, nreadout, current_TR * TR(seq)))
         new_nwait = length(pulse_effects) + nreadout[]
         if nwait == new_nwait
@@ -341,6 +340,8 @@ function walk_pathway!(block::BaseBuildingBlock, walker::PathwayWalker, pulse_ef
         # apply interrupt
         if interruption isa RFPulseComponent
             update_walker_pulse!(walker, pulse_effects, current_time)
+        elseif interruption isa InstantGradient
+            update_walker_instant_gradient!(interruption, walker, current_time)
         end
         current_index = index_inter
         if length(pulse_effects) == 0 && nreadout[] == 0
@@ -491,6 +492,22 @@ function update_walker_gradient!(gradient::GradientWaveform, walker::PathwayWalk
     end
 end
 
+"""
+    update_walker_instant_gradient!(gradient, walker)
+"""
+function update_walker_instant_gradient!(gradient::InstantGradient{N}, walker::PathwayWalker, gradient_start_time::VariableType) where {N}
+    if N == 1
+        qvec3 = qval(gradient) .* gradient.orientation
+    else
+        qvec3 = qval(gradient)
+    end
+    for key in (isnothing(gradient.group) ? [nothing] : [nothing, gradient.group])
+        update_gradient_tracker_till_time!(walker, key, gradient_start_time)
+        tracker = walker.gradient_trackers[key]
+        tracker.qvec = tracker.qvec  .+ qvec3
+    end
+end
+
 """
     duration_state_index(transverse, positive)
 
-- 
GitLab