From ab8c3ffa8caafaac3bddfccee6c7a031ea1c1dd7 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Mon, 4 Mar 2024 15:15:08 +0000
Subject: [PATCH] Fix plotting off-resonance pulses

Correctly show RFx and RFy
---
 src/plot.jl | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/plot.jl b/src/plot.jl
index 67ea626..beb74a0 100644
--- a/src/plot.jl
+++ b/src/plot.jl
@@ -112,7 +112,26 @@ function SequenceDiagram(bbb::BaseBuildingBlock)
                 if event isa InstantPulse
                     kwargs[symbol] = SinglePlotLine([0., duration(bbb)], [0., 0.], [delay], [flip_angle(event) * fn(phase(event))])
                 else
-                    points = [(t + delay, a * fn(p)) for (t, a, p) in zip(event.time, event.amplitude, event.phase)]
+                    points = Tuple{Float64, Float64}[]
+                    t_prev = p_prev = a_prev = nothing
+                    for (t, a, p) in zip(event.time, event.amplitude, event.phase)
+                        if !isnothing(t_prev)
+                            prev_phase_group = div(p_prev, 90, RoundDown)
+                            phase_group = div(p, 90, RoundDown)
+                            if phase_group != prev_phase_group
+                                for edge in (phase_group < prev_phase_group ? (prev_phase_group:-1:phase_group+1) : (prev_phase_group+1:phase_group))
+                                    edge_phase = edge * 90
+                                    edge_time = (abs(edge_phase - p_prev) * t + abs(edge_phase - p) * t_prev) / abs(p - p_prev)
+                                    edge_amplitude = (abs(edge_phase - p_prev) * a + abs(edge_phase - p) * a_prev) / abs(p - p_prev)
+                                    push!(points, (edge_time + delay, edge_amplitude * fn(edge_phase)))
+                                end
+                            end
+                        end
+                        push!(points, (t + delay, a * fn(p)))
+                        t_prev = t
+                        p_prev = p
+                        a_prev = a
+                    end
                     kwargs[symbol] = SinglePlotLine(points, duration(bbb))
                 end
             end
-- 
GitLab