From 674555766d26d8ee03e78368238ca9d59dcf973f Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Thu, 21 Mar 2024 11:59:07 +0000
Subject: [PATCH] Add instant pulses/gradients to linearised sequence

---
 src/containers/abstract.jl  |  4 +--
 src/containers/linearise.jl | 68 +++++++++++++++++++++++++++++--------
 2 files changed, 55 insertions(+), 17 deletions(-)

diff --git a/src/containers/abstract.jl b/src/containers/abstract.jl
index 2fd3d82..5e1ac8c 100644
--- a/src/containers/abstract.jl
+++ b/src/containers/abstract.jl
@@ -115,7 +115,7 @@ iter_blocks(container::ContainerBlock) = iter(container, Val(:block))
 Returns all the [`InstantPulse`](@ref) within the sequence with their timings
 """
 iter_instant_pulses(container::ContainerBlock) = iter(container, Val(:instantpulse))
-iter(component::Tuple{<:Number, <:InstantPulse}, ::Val{:instantpulse}) = [(0., component)]
+iter(component::Tuple{<:Number, <:InstantPulse}, ::Val{:instantpulse}) = [component]
 
 """
     iter_instant_gradients(sequence)
@@ -123,6 +123,6 @@ iter(component::Tuple{<:Number, <:InstantPulse}, ::Val{:instantpulse}) = [(0., c
 Returns all the [`InstantGradient`](@ref) within the sequence with their timings
 """
 iter_instant_gradients(container::ContainerBlock) = iter(container, Val(:instantgradient))
-iter(component::Tuple{<:Number, <:InstantPulse}, ::Val{:instantgradient}) = [(0., component)]
+iter(component::Tuple{<:Number, <:InstantGradient}, ::Val{:instantgradient}) = [component]
 
 end
\ No newline at end of file
diff --git a/src/containers/linearise.jl b/src/containers/linearise.jl
index 1aae063..79f0658 100644
--- a/src/containers/linearise.jl
+++ b/src/containers/linearise.jl
@@ -1,8 +1,8 @@
 module Linearise
 import StaticArrays: SVector
-import ...Components: GradientWaveform, split_timestep
+import ...Components: GradientWaveform, split_timestep, InstantPulse, InstantGradient3D
 import ...Variables: amplitude, phase, gradient_strength3, duration
-import ..Abstract: edge_times, start_time, end_time, ContainerBlock
+import ..Abstract: edge_times, start_time, end_time, ContainerBlock, iter_instant_gradients, iter_instant_pulses
 import ..BaseSequences: BaseSequence, Sequence
 import ..BuildingBlocks: BaseBuildingBlock
 
@@ -16,11 +16,9 @@ end
 """
     SequencePart(sequence, time1, time2)
 
-Represents the time between `time1` and `time2` of a larger [`Sequence`](@ref)
+Represents the time between `time1` and `time2` of a larger [`LinearSequence`](@ref)
 
 The gradient, RF amplitude, and RF phase are all be modeled as changing linearly during this time.
-
-See [`linearise`](@ref) to split a sequence into such linear parts.
 """
 struct SequencePart
     gradient :: LinearPart{SVector{3, Float64}}
@@ -32,6 +30,11 @@ end
 
 
 function SequencePart(sequence::BaseSequence{N}, time1::Number, time2::Number) where {N}
+    tmean = (time1 + time2) / 2
+    nTR = div(tmean, duration(sequence), RoundDown)
+    time1 -= nTR * duration(sequence)
+    time2 -= nTR * duration(sequence)
+    tmean -= nTR * duration(sequence)
     if -1e-9 < time1 < 0.
         time1 = 0.
     end
@@ -41,7 +44,6 @@ function SequencePart(sequence::BaseSequence{N}, time1::Number, time2::Number) w
     if !(0 <= time1 <= time2 <= duration(sequence))
         error("Sequence timings are out of bound")
     end
-    tmean = (time1 + time2) / 2
     for key in 1:N
         if (end_time(sequence, key) > tmean)
             return SequencePart(sequence[key], time1 - start_time(sequence, key), time2 - start_time(sequence, key))
@@ -94,11 +96,11 @@ The split times will include any time when (for any of the provided sequences):
 
 Continuous gradient waveforms or RF pulses might be split up further to ensure the linear approximations meet the required `precision` (see [`split_timestep`](@ref)).
 """
-split_times(sequence::BaseSequence; kwargs...) = split_times([sequence]; kwargs...)
+split_times(sequence::BaseSequence, args...; kwargs...) = split_times([sequence], args...; kwargs...)
 split_times(sequences::AbstractVector{<:BaseSequence}; kwargs...) = split_times(sequences, 0., maximum(duration.(sequences)); kwargs...)
 
 function split_times(sequences::AbstractVector{<:BaseSequence}, tstart::Number, tfinal::Number; precision=0.01, max_timestep=Inf)
-    edges = [tstart, tfinal]
+    edges = Float64.([tstart, tfinal])
     for sequence in sequences
         raw_edges = edge_times(sequence)
         nTR_start = Int(div(tstart, duration(sequence), RoundDown))
@@ -138,16 +140,52 @@ end
 
 
 """
-    linearise(sequence(s), times)
-    linearise(sequence(s), tstart, tfinal; max_timestep=Inf, precision=0.01)
+    LinearSequence(sequence, times)
+    LinearSequence(sequence; max_timestep=Inf, precision=0.01)
+    LinearSequence(sequence, time1, time2; max_timestep=Inf, precision=0.01)
+
+A piece-wise linear approximation of a sequence.
 
-Splits any [`Sequence`](@ref) into a series of [`SequencePart`](@ref) objects where the gradients/pulses are approximated to be linear.
+By default it represents the sequence between `time1=0` and `time2=TR`
 
-If the `times` are not explicitly set they will be obtained from [`split_times`](@ref) (using the values of `max_timestep` and `precision`).
+The gradient, RF amplitude, and RF phase are all be modeled as changing linearly during this time.
+
+If multiple sequences are provided a vector of `LinearSequence` objects are returned, all of which are split at the same time.
 """
-linearise(container::Union{BaseSequence, AbstractVector{<:BaseSequence}}, tstart::Number, tfinal::Number; kwargs...) = linearise(container, split_times(container, tstart, tfinal; kwargs...))
-linearise(containers::AbstractVector{<:BaseSequence}, times::AbstractVector{<:Number}) = [linearise(c, times) for c in containers]
-linearise(container::BaseSequence, times::AbstractVector{<:Number}) = [SequencePart(container, t1, t2) for (t1, t2) in zip(times[1:end-1], times[2:end])]
+struct LinearSequence
+    finite_parts :: Vector{SequencePart}
+    instant_pulses :: Vector{Tuple{Int, InstantPulse}}
+    instant_gradient :: Vector{Tuple{Int, InstantGradient3D}}
+end
+
+LinearSequence(sequence::BaseSequence; kwargs...) = LinearSequence(sequence, 0., duration(sequence); kwargs...)
+LinearSequence(container::Union{BaseSequence, AbstractVector{<:BaseSequence}}, tstart::Number, tfinal::Number; kwargs...) = LinearSequence(container, split_times(container, tstart, tfinal; kwargs...))
+LinearSequence(containers::AbstractVector{<:BaseSequence}, times::AbstractVector{<:Number}) = map(c -> LinearSequence(c, times), containers)
+function LinearSequence(container::BaseSequence, times::AbstractVector{<:Number}) 
+    parts = [SequencePart(container, t1, t2) for (t1, t2) in zip(times[1:end-1], times[2:end])]
+
+    tstart = times[1]
+    tfinal = times[end]
+
+    pulses = Tuple{Int, InstantPulse}[]
+    gradients = Tuple{Int, InstantGradient3D}[]
+    for nTR in div(tstart, duration(container), RoundDown):div(tfinal, duration(container), RoundUp)
+        for (to_store, func) in [
+            (pulses, iter_instant_pulses),
+            (gradients, iter_instant_gradients),
+        ]
+            for (time, pulse) in func(container)
+                real_time = time + nTR * duration(container)
+                if !(tstart <= real_time < tfinal)
+                    continue
+                end
+                index = findmin(t -> abs(t - real_time), times)[2]-1 
+                push!(to_store, (index, pulse))
+            end
+        end
+    end
+    return LinearSequence(parts, pulses, gradients)
+end
 
 
 end
\ No newline at end of file
-- 
GitLab