From e35e2ca0cdc586eb024700f32863e0ea20b1abf4 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Tue, 30 Jan 2024 14:50:17 +0000
Subject: [PATCH] Add helper update_*_till_time functions

---
 src/pathways.jl | 97 ++++++++++++++++++++++++++++++++-----------------
 1 file changed, 64 insertions(+), 33 deletions(-)

diff --git a/src/pathways.jl b/src/pathways.jl
index 20b8d40..6d225ec 100644
--- a/src/pathways.jl
+++ b/src/pathways.jl
@@ -55,7 +55,7 @@ struct Pathway
     readout_index :: Integer
 
     # computed
-    duration_states :: Dict{Any, SVector{4, <:VariableType}}
+    duration_states :: SVector{4, <:VariableType}
     qvec :: Dict{Any, SVector{3, <:VariableType}}
     bmat :: Dict{Any, SMatrix{3, 3, <:VariableType, 9}}
 end
@@ -67,7 +67,7 @@ function Pathway(sequence::Sequence, pulse_effects::AbstractVector, readout_inde
         sequence,
         pulse_effects,
         readout_index,
-        Dict(k => SVector{4}(v) for (k, v) in pairs(walker.duration_states)),
+        SVector{4}(walker.duration_states),
         Dict(k => SVector{3}(v) for (k, v) in pairs(walker.qvec)),
         Dict(k => SMatrix{3, 3}(v) for (k, v) in pairs(walker.bmat)),
     )
@@ -295,6 +295,7 @@ function walk_pathway!(::InstantReadout, walker::PathwayWalker, pulse_effects::V
     end
     nreadout[] -= 1
     if iszero(nreadout[])
+        update_walker_till_time!(walker, block_start_time)
         return true
     elseif nreadout[] < 0
         error("Pathway walker continued past the point where it should have ended. Did you start with a negative `nreadout`?")
@@ -303,6 +304,59 @@ function walk_pathway!(::InstantReadout, walker::PathwayWalker, pulse_effects::V
 end
 
 
+"""
+    update_walker_till_time!(walker::PathwayWalker, new_time[, (rotate, scale)])
+
+Updates all parts of a [`PathwayWalker`](@ref) up to the given time.
+
+This updates the `walker.duration_states` and the `bmat` for each gradient tracker.
+If `rotate` and `scale` are provided, then only the gradient tracker matching those properties will be updated.
+If that gradient tracker does not exist, it will be created.
+
+This function is used to get the `walker` up to date till the start of a gradient, pulse, or final readout.
+"""
+function update_walker_till_time!(walker::PathwayWalker, new_time::Float64, gradient_key=nothing)
+    # update duration state and pulse time
+    @assert new_time >= walker.last_pulse_time
+    index = duration_state_index(walker.is_transverse, walker.is_positive)
+    walker.duration_states[index] = walker.duration_states[index] + (new_time - walker.last_pulse_time)
+    walker.last_pulse_time = new_time
+
+    if isnothing(gradient_key)
+        for tracker in values(walker.gradient_trackers)
+            update_gradient_tracker_till_time!(tracker, new_time)
+        end
+    else
+        update_gradient_tracker_till_time!(walker.gradient_trackers, gradient_key, new_time)
+    end
+end
+
+"""
+    update_gradient_tracker_till_time!(walker::PathwayWalker, key, new_time)
+    update_gradient_tracker_till_time!(tracker::GradientTracker, new_time)
+
+Update the `bmat` for any time passed since the last update (assuming there will no gradients during that period).
+
+The `bmat` is updated with the outer produce of `qvec` with itself multiplied by the time since the last update.
+
+When called with the first signature the tracker will be created from scratch if a tracker with that `key` does not exist.
+"""
+function update_gradient_tracker_till_time!(walker::PathwayWalker, key::Tuple, new_time::Float64)
+    if key in keys(walker.gradient_trackers)
+        walker.gradient_trackers[key] = GradientTracker()
+    end
+    update_gradient_tracker_till_time!(walker.gradient_trackers[key], new_time)
+end
+
+function update_gradient_tracker_till_time!(gradient_tracker::GradientTracker, new_time::Float64)
+    @assert new_time >= gradient_tracker.last_gradient_time
+    gradient_tracker.bmat += (
+        (gradient_tracker.qvec .* gradient_tracker.qvec') .* 
+        (pulse_time - gradient_tracker.last_gradient_time)
+    )
+    gradient_tracker.last_gradient_time = pulse_time
+end
+
 """
     update_walker_pulse!(walker::PathwayWalker, pulse_effects::Vector, pulse_time)
 
@@ -327,23 +381,7 @@ function update_walker_pulse!(walker::PathwayWalker, pulse_effects::AbstractVect
         return
     end
 
-    # update qvec/bmat
-    if walker.is_transverse
-        for gradient_tracker in values(walker.gradient_trackers)
-            gradient_tracker.bmat += (
-                (gradient_tracker.qvec .* gradient_tracker.qvec') .* 
-                (pulse_time - gradient_tracker.last_gradient_time)
-            )
-            gradient_tracker.last_gradient_time = pulse_time
-        end
-    end
-    prev_sign = walker.is_positive
-    
-    # update durations
-    index = duration_state_index(walker.is_transverse, walker.is_positive)
-    walker.duration_states[index] = walker.duration_states[index] + (pulse_time - walker.last_pulse_time)
-
-    walker.last_pulse_time = pulse_time
+    update_walker_till_time!(walker, pulse_time)
 
     # -transverse, +longitudinal, +transverse, -longitudinal, -transverse, +longitudinal
     ordering = [(true, false), (false, true), (true, true), (false, false), (true, false)]
@@ -386,22 +424,15 @@ function update_walker_gradient!(gradient::GradientBlock, walker::PathwayWalker,
         return
     end
 
-    # make sure the appropriate gradient tracker exists
+    # update gradient tracker till start of gradient
     key = (gradient.scale, gradient.rotate)
-    if !(key in keys(walker.gradient_trackers))
-        walker.gradient_trackers[key] = GradientTracker()
-    end
-    tracker = walker.gradient_trackers[key]
-
-    # update bmat till start of gradient
-    tracker.bmat = tracker.bmat .+ (
-        (tracker.qvec .* tracker.qvec') .* 
-        (gradient_start_time - tracker.last_gradient_time)
-    )
+    update_gradient_tracker_till_time!(gradient, key, gradient_start_time)
 
-    tracker.last_gradient_time = gradient_start_time
-    tracker.bmat = tracker.bmat .+ bmat(gradient, tracker.qvec, internal_start_time, internal_end_time)
-    tracker.qvec = tracker.qvec .+ qvec(gradient, internal_start_time, internal_end_time)
+    # update qvec/bmat during gradient
+    tracker = walker.gradient_trackers[key]
+    tracker.bmat = tracker.bmat .+ bmat(gradient, tracker.qvec)
+    tracker.qvec = tracker.qvec .+ qvec(gradient)
+    tracker.last_gradient_time = gradient_start_time + duration(gradient)
 end
 
 """
-- 
GitLab