From 4d3dcc3c9a487b3521dd6a8dccd6bc8874f8498f Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Wed, 6 Mar 2024 14:37:51 +0000
Subject: [PATCH] Switch from split_times to single split timestep

---
 src/components/abstract_types.jl              | 12 +++--
 src/components/components.jl                  |  2 +-
 .../gradient_waveforms/gradient_waveforms.jl  |  4 +-
 src/components/pulses/constant_pulses.jl      |  4 +-
 src/components/pulses/generic_pulses.jl       | 49 +++++++------------
 src/components/pulses/sinc_pulses.jl          |  8 ++-
 src/components/readouts/readouts.jl           |  4 +-
 src/containers/linearise.jl                   | 13 +++--
 8 files changed, 46 insertions(+), 50 deletions(-)

diff --git a/src/components/abstract_types.jl b/src/components/abstract_types.jl
index 6ef5cb2..3aeb683 100644
--- a/src/components/abstract_types.jl
+++ b/src/components/abstract_types.jl
@@ -37,10 +37,16 @@ abstract type ReadoutComponent <: EventComponent end
 
 
 """
-    split_times(component, t1, t2, precision)
+    split_timestep(component, precision)
 
-Indicates at what timepoints a given [`BaseComponent`](@ref) should be split during linearisation to achieve the given precision.
+Indicates the maximum timestep that a component can be linearised with and still achieve the required `precision`.
+
+Typically, this will be determined by the maximum second derivative:
+
+``\\sqrt{\\frac{2 \\epsilon}{max(|d^2y/dx^2|)}}``
+
+It should be infinite if the component is linear.
 """
-function split_times end
+function split_timestep end
 
 end
\ No newline at end of file
diff --git a/src/components/components.jl b/src/components/components.jl
index 4fa1035..880f6d2 100644
--- a/src/components/components.jl
+++ b/src/components/components.jl
@@ -5,7 +5,7 @@ include("instant_gradients.jl")
 include("pulses/pulses.jl")
 include("readouts/readouts.jl")
 
-import .AbstractTypes: BaseComponent, GradientWaveform, EventComponent, RFPulseComponent, ReadoutComponent, split_times
+import .AbstractTypes: BaseComponent, GradientWaveform, EventComponent, RFPulseComponent, ReadoutComponent, split_timestep
 import .GradientWaveforms: ConstantGradient, ChangingGradient, NoGradient, split_gradient
 import .InstantGradients: InstantGradient
 import .Pulses: GenericPulse, InstantPulse, SincPulse, ConstantPulse
diff --git a/src/components/gradient_waveforms/gradient_waveforms.jl b/src/components/gradient_waveforms/gradient_waveforms.jl
index 61c00a0..bacba79 100644
--- a/src/components/gradient_waveforms/gradient_waveforms.jl
+++ b/src/components/gradient_waveforms/gradient_waveforms.jl
@@ -21,11 +21,11 @@ include("constant_gradient_blocks.jl")
 include("no_gradient_blocks.jl")
 
 
-import ..AbstractTypes: GradientWaveform, split_times
+import ..AbstractTypes: GradientWaveform, split_timestep
 import .NoGradientBlocks: NoGradient
 import .ChangingGradientBlocks: ChangingGradient, split_gradient
 import .ConstantGradientBlocks: ConstantGradient
 
-split_times(wv::GradientWaveform, t1, t2, precision) = [t1, t2]
+split_timestep(wv::GradientWaveform, precision) = Inf
 
 end
\ No newline at end of file
diff --git a/src/components/pulses/constant_pulses.jl b/src/components/pulses/constant_pulses.jl
index c08db09..f2891a0 100644
--- a/src/components/pulses/constant_pulses.jl
+++ b/src/components/pulses/constant_pulses.jl
@@ -1,6 +1,6 @@
 module ConstantPulses
 import JuMP: @constraint
-import ...AbstractTypes: RFPulseComponent, split_times
+import ...AbstractTypes: RFPulseComponent, split_timestep
 import ....BuildSequences: global_model
 import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, set_simple_constraints!, frequency, make_generic, get_free_variable
 import ..GenericPulses: GenericPulse
@@ -61,7 +61,7 @@ function make_generic(block::ConstantPulse)
 end
 
 
-split_times(pulse::ConstantPulse, t1, t2, precision) = [t1, t2]
+split_timestep(pulse::ConstantPulse, precision) = Inf
 
 
 end
\ No newline at end of file
diff --git a/src/components/pulses/generic_pulses.jl b/src/components/pulses/generic_pulses.jl
index 6a5750a..8b9f72e 100644
--- a/src/components/pulses/generic_pulses.jl
+++ b/src/components/pulses/generic_pulses.jl
@@ -1,6 +1,7 @@
 module GenericPulses
 
-import ...AbstractTypes: RFPulseComponent, split_times
+import Polynomials: fit
+import ...AbstractTypes: RFPulseComponent, split_timestep
 import ....Variables: duration, amplitude, effective_time, flip_angle, make_generic, phase
 
 
@@ -78,38 +79,26 @@ end
 
 make_generic(gp::GenericPulse) = gp
 
-function split_times(gp::GenericPulse, t1, t2, precision)
-    real_amplitude_precision = precision * amplitude(gp)
-    real_phase_precision = 180 * precision
-
-    current_index = find(t -> t >= t1, gp.time)
-    final_index = find(t -> t > t2, gp.time)
-    if isnothing(final_index)
-        final_index = length(gp.time)
-    end
-    times = [t1]
-    while gp.time[current_index] < t2
-        current_time = gp.time[current_index]
-        current_amplitude = gp.amplitude[current_index]
-        current_phase = gp.phase(current_index)
-        slope_amplitude = (gp.amplitude[current_index + 1] - current_amplitude) / (gp.time[current_index + 1] - current_time)
-        slope_phase = (gp.phase[current_index + 1] - current_phase) / (gp.time[current_index + 1] - current_time)
-
-        next_index = current_index
-        for next_index in current_index+2:final_index
-            if (
-                abs(slope_amplitude * (gp.time[next_index] - current_time) - gp.amplitude[next_index]) > real_amplitude_precision ||
-                abs(slope_phase * (gp.time[next_index] - current_time) - gp.phase[next_index]) > real_phase_precision
-            )
-                next_index -= 1
-                push!(times, gp.time[next_index])
-                break
+function split_timestep(gp::GenericPulse, precision)
+    function second_derivative(arr)
+        max_second_der = 0.
+        for index in 2:length(arr)-1
+            poly = fit(gp.times[index-1:index+1], arr[index-1:index+1])
+            second_der = abs(poly.coeffs[end])
+            if second_der > max_second_der
+                max_second_der = second_der
             end
         end
-        current_index = next_index
+        return max_second_der
+    end
+    max_second_der = max(
+        second_derivative(gp.amplitude ./ maximum(gp.amplitude)),
+        second_derivative(gp.phase ./ 360),
+    )
+    if iszero(max_second_der)
+        return Inf
     end
-    push!(times, t2)
-    return times
+    return sqrt(2 * precision / max_second_der)
 end
 
 
diff --git a/src/components/pulses/sinc_pulses.jl b/src/components/pulses/sinc_pulses.jl
index fa562c9..e5d43bb 100644
--- a/src/components/pulses/sinc_pulses.jl
+++ b/src/components/pulses/sinc_pulses.jl
@@ -3,7 +3,7 @@ import JuMP: @constraint
 import QuadGK: quadgk
 import ....BuildSequences: global_model
 import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, set_simple_constraints!, frequency, make_generic, get_free_variable
-import ...AbstractTypes: RFPulseComponent, split_times
+import ...AbstractTypes: RFPulseComponent, split_timestep
 import ..GenericPulses: GenericPulse
 
 """
@@ -101,11 +101,9 @@ function make_generic(block::SincPulse)
     return GenericPulse(times, amplitudes, phases, effective_time(block))
 end
 
-function split_times(block::SincPulse, t1, t2, precision)
+function split_timestep(block::SincPulse, precision)
     max_second_derivative = π^2/3 * (amplitude(block) / lobe_duration(block))^2
-    step_size = sqrt(2 * precision / max_second_derivative)
-    npoints = Int(div(t2 - t1, step_size, RoundUp))
-    return range(t1, t2, length=npoints)
+    return sqrt(2 * precision / max_second_derivative)
 end
 
 end
\ No newline at end of file
diff --git a/src/components/readouts/readouts.jl b/src/components/readouts/readouts.jl
index 45e3432..7271159 100644
--- a/src/components/readouts/readouts.jl
+++ b/src/components/readouts/readouts.jl
@@ -2,9 +2,9 @@ module Readouts
 include("ADCs.jl")
 include("single_readouts.jl")
 
-import ..AbstractTypes: ReadoutComponent, split_times
+import ..AbstractTypes: ReadoutComponent, split_timestep
 import .ADCs: ADC, readout_times
 import .SingleReadouts: SingleReadout
 
-split_times(rc::ReadoutComponent, t1, t2, precision) = [t1, t2]
+split_times(rc::ReadoutComponent, precision) = Inf
 end
\ No newline at end of file
diff --git a/src/containers/linearise.jl b/src/containers/linearise.jl
index 5ca43bd..86ac78c 100644
--- a/src/containers/linearise.jl
+++ b/src/containers/linearise.jl
@@ -1,6 +1,6 @@
 module Linearise
 import StaticArrays: SVector
-import ...Components: GradientWaveform, split_times
+import ...Components: GradientWaveform, split_timestep
 import ...Variables: amplitude, phase, gradient_strength3, duration
 import ..Abstract: edge_times, start_time, end_time
 import ..BaseSequences: BaseSequence, Sequence
@@ -45,15 +45,18 @@ function linearise(bb::BaseBuildingBlock; precision=0.01)
         tmean = (t1 + t2) / 2
 
         # determine where to split this domain
-        possible_splits = Any[[t1, t2], ]
+        timestep = Inf
         for key in keys(bb)
             if !(start_time(bb, key) < tmean < end_time(bb, key))
                 continue
             end
-            block = bb[key]
-            append!(possible_splits, start_time(bb, key) .+ split_times(block, t1 - start_time(bb, key), t2 - start_time(bb, key), precision))
+            new_timestep = split_timestep(block, precision)
+            if new_timestep < timestep
+                timestep = new_timestep
+            end
         end
-        tsplits = argmax(length, possible_splits)
+        nsteps = div(t2 - t1, timestep, RoundUp)
+        tsplits = range(t1, t2, length=nsteps)
 
         for (t1b, t2b) in zip(tsplits[1:end-1], tsplits[2:end])
             tmeanb = (t1b + t2b) / 2
-- 
GitLab