From e6e1a8a2437518a4721dc3d7f516131a8b7f8754 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Tue, 30 Jan 2024 14:20:28 +0000
Subject: [PATCH] update_walker_gradient! should always apply the whole
 gradient block

---
 src/pathways.jl | 30 +++++++++++++-----------------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/src/pathways.jl b/src/pathways.jl
index dfa5805..8eee5c7 100644
--- a/src/pathways.jl
+++ b/src/pathways.jl
@@ -364,29 +364,25 @@ The following steps will be taken:
 
 This requires [`bmat`](@ref) and [`qvec`](@ref) to be implemented for the [`GradientBlock`](@ref).
 """
-function update_walker_gradient!(gradient::GradientBlock, walker::PathwayWalker, gradient_start_time::VariableType, internal_start_time, internal_end_time)
+function update_walker_gradient!(gradient::GradientBlock, walker::PathwayWalker, gradient_start_time::VariableType)
     if walker.transverse
         return
     end
 
-    if iszero(internal_start_time) || isnothing(internal_start_time)
-        # only worry about this for the first call
-
-        # make sure the appropriate gradient tracker exists
-        key = (gradient.scale, gradient.rotate)
-        if !(key in keys(walker.gradient_trackers))
-            walker.gradient_trackers[key] = GradientTracker()
-        end
-        tracker = walker.gradient_trackers[key]
-
-        # update bmat till start of gradient
-        tracker.bmat = tracker.bmat .+ (
-            (tracker.qvec .* tracker.qvec') .* 
-            (gradient_start_time - tracker.last_gradient_time)
-        )
-        tracker.last_gradient_time = gradient_start_time
+    # make sure the appropriate gradient tracker exists
+    key = (gradient.scale, gradient.rotate)
+    if !(key in keys(walker.gradient_trackers))
+        walker.gradient_trackers[key] = GradientTracker()
     end
+    tracker = walker.gradient_trackers[key]
+
+    # update bmat till start of gradient
+    tracker.bmat = tracker.bmat .+ (
+        (tracker.qvec .* tracker.qvec') .* 
+        (gradient_start_time - tracker.last_gradient_time)
+    )
 
+    tracker.last_gradient_time = gradient_start_time
     tracker.bmat = tracker.bmat .+ bmat(gradient, tracker.qvec, internal_start_time, internal_end_time)
     tracker.qvec = tracker.qvec .+ qvec(gradient, internal_start_time, internal_end_time)
 end
-- 
GitLab