From afda6fd279affe6d68bc3ab17ee1746daab8c710 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Wed, 17 Apr 2024 15:41:49 +0100
Subject: [PATCH] Fix GenericPulse frequency calculation

---
 src/components/pulses/generic_pulses.jl | 12 +++++------
 test/runtests.jl                        |  1 +
 test/test_components.jl                 | 27 +++++++++++++++++++++++++
 test/test_sequences.jl                  | 10 +++++----
 4 files changed, 40 insertions(+), 10 deletions(-)
 create mode 100644 test/test_components.jl

diff --git a/src/components/pulses/generic_pulses.jl b/src/components/pulses/generic_pulses.jl
index d70c0d7..88cdbd9 100644
--- a/src/components/pulses/generic_pulses.jl
+++ b/src/components/pulses/generic_pulses.jl
@@ -102,16 +102,16 @@ end
 function frequency(gp::GenericPulse, time::Number)
     i2 = findfirst(t -> t > time, gp.time)
     if isnothing(i2)
-        @assert time ≈ fp.time[end]
-        i2 = length(time)
+        @assert time ≈ gp.time[end]
+        i2 = length(gp.time)
     end
-    if i2 != length(time) && time ≈ fp.time[i2 + 1]
-        i2 += 1
+    if !isone(i2) && time ≈ gp.time[i2 - 1]
+        i2 -= 1
     end
-    if time ≈ fp.time[i2]
+    if time ≈ gp.time[i2]
         if i2 == 1
             return (gp.phase[2] - gp.phase[1]) / (gp.time[2] - gp.time[1]) / 360
-        elseif i2 == length(time)
+        elseif i2 == length(gp.time)
             return (gp.phase[end] - gp.phase[end-1]) / (gp.time[end] - gp.time[end-1]) / 360
         end
         return (gp.phase[i2 + 1] - gp.phase[i2 - 1]) / (gp.time[i2 + 1] - gp.time[i2 - 1]) / 360
diff --git a/test/runtests.jl b/test/runtests.jl
index 2ba1714..bba1f16 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,6 +2,7 @@ using MRIBuilder
 using Test
 
 @testset "MRIBuilder.jl" begin
+    include("test_components.jl")
     include("test_sequences.jl")
     include("test_IO.jl")
     include("test_plot.jl")
diff --git a/test/test_components.jl b/test/test_components.jl
new file mode 100644
index 0000000..bdc9c27
--- /dev/null
+++ b/test/test_components.jl
@@ -0,0 +1,27 @@
+@testset "test_components.jl" begin
+    @testset "probing GenericPulse" begin
+        gp = GenericPulse([0., 1., 3., 4.], [1., 2., 3., 0.], [0., 90., 180., 180.])
+        @test amplitude(gp, 0.) ≈ 1.
+        @test amplitude(gp, 1.) ≈ 2.
+        @test amplitude(gp, 0.25) ≈ 1.25
+        @test amplitude(gp, 2.5) ≈ 2.75
+        @test amplitude(gp, 3.5) ≈ 1.5
+        @test amplitude(gp, 4.) ≈ 0.
+
+        @test phase(gp, 0.) ≈ 0.
+        @test phase(gp, 1.) ≈ 90.
+        @test phase(gp, 0.25) ≈ 90/4
+        @test phase(gp, 2.5) ≈ 90 * 7/4
+        @test phase(gp, 3.5) ≈ 180.
+        @test phase(gp, 4.) ≈ 180.
+
+        @test frequency(gp, 0.) ≈ 1/4
+        @test frequency(gp, 1.) ≈ 1/6
+        @test frequency(gp, 0.25) ≈ 1/4
+        @test frequency(gp, 2.5) ≈ 1/8
+        @test frequency(gp, 3.) ≈ 1/12
+        @test frequency(gp, 3.5) ≈ 0.
+        @test frequency(gp, 4.) ≈ 0.
+    end
+
+end
diff --git a/test/test_sequences.jl b/test/test_sequences.jl
index d82c008..fea634b 100644
--- a/test/test_sequences.jl
+++ b/test/test_sequences.jl
@@ -113,14 +113,16 @@
                 @test 1. - t_pulse ≈ min_rise_time rtol=1e-4
                 @test flip_angle(pulse) ≈ 90.
                 @test iszero(phase(pulse))
+                @test iszero(phase(seq, 1.))
+                @test iszero(frequency(seq, 1.))
+
                 @test isnothing(get_pulse(seq, 10.))
+                @test isnan(phase(seq, 10.))
+                @test isnan(frequency(seq, 10.))
+    
                 gp = GenericPulse(pulse, 0., 1.)
                 @test gp.amplitude[1] ≈ 0. atol=1e-8
                 @test gp.amplitude[end] ≈ amplitude(pulse, 1.) rtol=1e-2
-                @test iszero(phase(pulse, 1.)) rtol=1e-2
-                @test iszero(frequency(pulse, 1.)) rtol=1e-2
-                @test isnan(phase(pulse, 1.)) rtol=1e-2
-                @test isnan(frequency(pulse, 1.)) rtol=1e-2
                 @test all(iszero.(gp.phase))
 
                 (pulse, t_pulse) = get_pulse(seq, 35.)
-- 
GitLab