Skip to content
Snippets Groups Projects
Verified Commit e6e1a8a2 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

update_walker_gradient! should always apply the whole gradient block

parent 3539bf73
No related branches found
No related tags found
No related merge requests found
...@@ -364,29 +364,25 @@ The following steps will be taken: ...@@ -364,29 +364,25 @@ The following steps will be taken:
This requires [`bmat`](@ref) and [`qvec`](@ref) to be implemented for the [`GradientBlock`](@ref). This requires [`bmat`](@ref) and [`qvec`](@ref) to be implemented for the [`GradientBlock`](@ref).
""" """
function update_walker_gradient!(gradient::GradientBlock, walker::PathwayWalker, gradient_start_time::VariableType, internal_start_time, internal_end_time) function update_walker_gradient!(gradient::GradientBlock, walker::PathwayWalker, gradient_start_time::VariableType)
if walker.transverse if walker.transverse
return return
end end
if iszero(internal_start_time) || isnothing(internal_start_time) # make sure the appropriate gradient tracker exists
# only worry about this for the first call key = (gradient.scale, gradient.rotate)
if !(key in keys(walker.gradient_trackers))
# make sure the appropriate gradient tracker exists walker.gradient_trackers[key] = GradientTracker()
key = (gradient.scale, gradient.rotate)
if !(key in keys(walker.gradient_trackers))
walker.gradient_trackers[key] = GradientTracker()
end
tracker = walker.gradient_trackers[key]
# update bmat till start of gradient
tracker.bmat = tracker.bmat .+ (
(tracker.qvec .* tracker.qvec') .*
(gradient_start_time - tracker.last_gradient_time)
)
tracker.last_gradient_time = gradient_start_time
end end
tracker = walker.gradient_trackers[key]
# update bmat till start of gradient
tracker.bmat = tracker.bmat .+ (
(tracker.qvec .* tracker.qvec') .*
(gradient_start_time - tracker.last_gradient_time)
)
tracker.last_gradient_time = gradient_start_time
tracker.bmat = tracker.bmat .+ bmat(gradient, tracker.qvec, internal_start_time, internal_end_time) tracker.bmat = tracker.bmat .+ bmat(gradient, tracker.qvec, internal_start_time, internal_end_time)
tracker.qvec = tracker.qvec .+ qvec(gradient, internal_start_time, internal_end_time) tracker.qvec = tracker.qvec .+ qvec(gradient, internal_start_time, internal_end_time)
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment