From e6e1a8a2437518a4721dc3d7f516131a8b7f8754 Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk> Date: Tue, 30 Jan 2024 14:20:28 +0000 Subject: [PATCH] update_walker_gradient! should always apply the whole gradient block --- src/pathways.jl | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/pathways.jl b/src/pathways.jl index dfa5805..8eee5c7 100644 --- a/src/pathways.jl +++ b/src/pathways.jl @@ -364,29 +364,25 @@ The following steps will be taken: This requires [`bmat`](@ref) and [`qvec`](@ref) to be implemented for the [`GradientBlock`](@ref). """ -function update_walker_gradient!(gradient::GradientBlock, walker::PathwayWalker, gradient_start_time::VariableType, internal_start_time, internal_end_time) +function update_walker_gradient!(gradient::GradientBlock, walker::PathwayWalker, gradient_start_time::VariableType) if walker.transverse return end - if iszero(internal_start_time) || isnothing(internal_start_time) - # only worry about this for the first call - - # make sure the appropriate gradient tracker exists - key = (gradient.scale, gradient.rotate) - if !(key in keys(walker.gradient_trackers)) - walker.gradient_trackers[key] = GradientTracker() - end - tracker = walker.gradient_trackers[key] - - # update bmat till start of gradient - tracker.bmat = tracker.bmat .+ ( - (tracker.qvec .* tracker.qvec') .* - (gradient_start_time - tracker.last_gradient_time) - ) - tracker.last_gradient_time = gradient_start_time + # make sure the appropriate gradient tracker exists + key = (gradient.scale, gradient.rotate) + if !(key in keys(walker.gradient_trackers)) + walker.gradient_trackers[key] = GradientTracker() end + tracker = walker.gradient_trackers[key] + + # update bmat till start of gradient + tracker.bmat = tracker.bmat .+ ( + (tracker.qvec .* tracker.qvec') .* + (gradient_start_time - tracker.last_gradient_time) + ) + tracker.last_gradient_time = gradient_start_time tracker.bmat = tracker.bmat .+ bmat(gradient, tracker.qvec, internal_start_time, internal_end_time) tracker.qvec = tracker.qvec .+ qvec(gradient, internal_start_time, internal_end_time) end -- GitLab