From 4d3dcc3c9a487b3521dd6a8dccd6bc8874f8498f Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk> Date: Wed, 6 Mar 2024 14:37:51 +0000 Subject: [PATCH] Switch from split_times to single split timestep --- src/components/abstract_types.jl | 12 +++-- src/components/components.jl | 2 +- .../gradient_waveforms/gradient_waveforms.jl | 4 +- src/components/pulses/constant_pulses.jl | 4 +- src/components/pulses/generic_pulses.jl | 49 +++++++------------ src/components/pulses/sinc_pulses.jl | 8 ++- src/components/readouts/readouts.jl | 4 +- src/containers/linearise.jl | 13 +++-- 8 files changed, 46 insertions(+), 50 deletions(-) diff --git a/src/components/abstract_types.jl b/src/components/abstract_types.jl index 6ef5cb2..3aeb683 100644 --- a/src/components/abstract_types.jl +++ b/src/components/abstract_types.jl @@ -37,10 +37,16 @@ abstract type ReadoutComponent <: EventComponent end """ - split_times(component, t1, t2, precision) + split_timestep(component, precision) -Indicates at what timepoints a given [`BaseComponent`](@ref) should be split during linearisation to achieve the given precision. +Indicates the maximum timestep that a component can be linearised with and still achieve the required `precision`. + +Typically, this will be determined by the maximum second derivative: + +``\\sqrt{\\frac{2 \\epsilon}{max(|d^2y/dx^2|)}}`` + +It should be infinite if the component is linear. """ -function split_times end +function split_timestep end end \ No newline at end of file diff --git a/src/components/components.jl b/src/components/components.jl index 4fa1035..880f6d2 100644 --- a/src/components/components.jl +++ b/src/components/components.jl @@ -5,7 +5,7 @@ include("instant_gradients.jl") include("pulses/pulses.jl") include("readouts/readouts.jl") -import .AbstractTypes: BaseComponent, GradientWaveform, EventComponent, RFPulseComponent, ReadoutComponent, split_times +import .AbstractTypes: BaseComponent, GradientWaveform, EventComponent, RFPulseComponent, ReadoutComponent, split_timestep import .GradientWaveforms: ConstantGradient, ChangingGradient, NoGradient, split_gradient import .InstantGradients: InstantGradient import .Pulses: GenericPulse, InstantPulse, SincPulse, ConstantPulse diff --git a/src/components/gradient_waveforms/gradient_waveforms.jl b/src/components/gradient_waveforms/gradient_waveforms.jl index 61c00a0..bacba79 100644 --- a/src/components/gradient_waveforms/gradient_waveforms.jl +++ b/src/components/gradient_waveforms/gradient_waveforms.jl @@ -21,11 +21,11 @@ include("constant_gradient_blocks.jl") include("no_gradient_blocks.jl") -import ..AbstractTypes: GradientWaveform, split_times +import ..AbstractTypes: GradientWaveform, split_timestep import .NoGradientBlocks: NoGradient import .ChangingGradientBlocks: ChangingGradient, split_gradient import .ConstantGradientBlocks: ConstantGradient -split_times(wv::GradientWaveform, t1, t2, precision) = [t1, t2] +split_timestep(wv::GradientWaveform, precision) = Inf end \ No newline at end of file diff --git a/src/components/pulses/constant_pulses.jl b/src/components/pulses/constant_pulses.jl index c08db09..f2891a0 100644 --- a/src/components/pulses/constant_pulses.jl +++ b/src/components/pulses/constant_pulses.jl @@ -1,6 +1,6 @@ module ConstantPulses import JuMP: @constraint -import ...AbstractTypes: RFPulseComponent, split_times +import ...AbstractTypes: RFPulseComponent, split_timestep import ....BuildSequences: global_model import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, set_simple_constraints!, frequency, make_generic, get_free_variable import ..GenericPulses: GenericPulse @@ -61,7 +61,7 @@ function make_generic(block::ConstantPulse) end -split_times(pulse::ConstantPulse, t1, t2, precision) = [t1, t2] +split_timestep(pulse::ConstantPulse, precision) = Inf end \ No newline at end of file diff --git a/src/components/pulses/generic_pulses.jl b/src/components/pulses/generic_pulses.jl index 6a5750a..8b9f72e 100644 --- a/src/components/pulses/generic_pulses.jl +++ b/src/components/pulses/generic_pulses.jl @@ -1,6 +1,7 @@ module GenericPulses -import ...AbstractTypes: RFPulseComponent, split_times +import Polynomials: fit +import ...AbstractTypes: RFPulseComponent, split_timestep import ....Variables: duration, amplitude, effective_time, flip_angle, make_generic, phase @@ -78,38 +79,26 @@ end make_generic(gp::GenericPulse) = gp -function split_times(gp::GenericPulse, t1, t2, precision) - real_amplitude_precision = precision * amplitude(gp) - real_phase_precision = 180 * precision - - current_index = find(t -> t >= t1, gp.time) - final_index = find(t -> t > t2, gp.time) - if isnothing(final_index) - final_index = length(gp.time) - end - times = [t1] - while gp.time[current_index] < t2 - current_time = gp.time[current_index] - current_amplitude = gp.amplitude[current_index] - current_phase = gp.phase(current_index) - slope_amplitude = (gp.amplitude[current_index + 1] - current_amplitude) / (gp.time[current_index + 1] - current_time) - slope_phase = (gp.phase[current_index + 1] - current_phase) / (gp.time[current_index + 1] - current_time) - - next_index = current_index - for next_index in current_index+2:final_index - if ( - abs(slope_amplitude * (gp.time[next_index] - current_time) - gp.amplitude[next_index]) > real_amplitude_precision || - abs(slope_phase * (gp.time[next_index] - current_time) - gp.phase[next_index]) > real_phase_precision - ) - next_index -= 1 - push!(times, gp.time[next_index]) - break +function split_timestep(gp::GenericPulse, precision) + function second_derivative(arr) + max_second_der = 0. + for index in 2:length(arr)-1 + poly = fit(gp.times[index-1:index+1], arr[index-1:index+1]) + second_der = abs(poly.coeffs[end]) + if second_der > max_second_der + max_second_der = second_der end end - current_index = next_index + return max_second_der + end + max_second_der = max( + second_derivative(gp.amplitude ./ maximum(gp.amplitude)), + second_derivative(gp.phase ./ 360), + ) + if iszero(max_second_der) + return Inf end - push!(times, t2) - return times + return sqrt(2 * precision / max_second_der) end diff --git a/src/components/pulses/sinc_pulses.jl b/src/components/pulses/sinc_pulses.jl index fa562c9..e5d43bb 100644 --- a/src/components/pulses/sinc_pulses.jl +++ b/src/components/pulses/sinc_pulses.jl @@ -3,7 +3,7 @@ import JuMP: @constraint import QuadGK: quadgk import ....BuildSequences: global_model import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, set_simple_constraints!, frequency, make_generic, get_free_variable -import ...AbstractTypes: RFPulseComponent, split_times +import ...AbstractTypes: RFPulseComponent, split_timestep import ..GenericPulses: GenericPulse """ @@ -101,11 +101,9 @@ function make_generic(block::SincPulse) return GenericPulse(times, amplitudes, phases, effective_time(block)) end -function split_times(block::SincPulse, t1, t2, precision) +function split_timestep(block::SincPulse, precision) max_second_derivative = π^2/3 * (amplitude(block) / lobe_duration(block))^2 - step_size = sqrt(2 * precision / max_second_derivative) - npoints = Int(div(t2 - t1, step_size, RoundUp)) - return range(t1, t2, length=npoints) + return sqrt(2 * precision / max_second_derivative) end end \ No newline at end of file diff --git a/src/components/readouts/readouts.jl b/src/components/readouts/readouts.jl index 45e3432..7271159 100644 --- a/src/components/readouts/readouts.jl +++ b/src/components/readouts/readouts.jl @@ -2,9 +2,9 @@ module Readouts include("ADCs.jl") include("single_readouts.jl") -import ..AbstractTypes: ReadoutComponent, split_times +import ..AbstractTypes: ReadoutComponent, split_timestep import .ADCs: ADC, readout_times import .SingleReadouts: SingleReadout -split_times(rc::ReadoutComponent, t1, t2, precision) = [t1, t2] +split_times(rc::ReadoutComponent, precision) = Inf end \ No newline at end of file diff --git a/src/containers/linearise.jl b/src/containers/linearise.jl index 5ca43bd..86ac78c 100644 --- a/src/containers/linearise.jl +++ b/src/containers/linearise.jl @@ -1,6 +1,6 @@ module Linearise import StaticArrays: SVector -import ...Components: GradientWaveform, split_times +import ...Components: GradientWaveform, split_timestep import ...Variables: amplitude, phase, gradient_strength3, duration import ..Abstract: edge_times, start_time, end_time import ..BaseSequences: BaseSequence, Sequence @@ -45,15 +45,18 @@ function linearise(bb::BaseBuildingBlock; precision=0.01) tmean = (t1 + t2) / 2 # determine where to split this domain - possible_splits = Any[[t1, t2], ] + timestep = Inf for key in keys(bb) if !(start_time(bb, key) < tmean < end_time(bb, key)) continue end - block = bb[key] - append!(possible_splits, start_time(bb, key) .+ split_times(block, t1 - start_time(bb, key), t2 - start_time(bb, key), precision)) + new_timestep = split_timestep(block, precision) + if new_timestep < timestep + timestep = new_timestep + end end - tsplits = argmax(length, possible_splits) + nsteps = div(t2 - t1, timestep, RoundUp) + tsplits = range(t1, t2, length=nsteps) for (t1b, t2b) in zip(tsplits[1:end-1], tsplits[2:end]) tmeanb = (t1b + t2b) / 2 -- GitLab