From 550393a8b868e5d3dc7038f8f5e90168caa2e7b1 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Wed, 21 Feb 2024 14:39:25 +0000
Subject: [PATCH] Match trapezoid timings by default

---
 src/parts/helper_functions.jl | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/parts/helper_functions.jl b/src/parts/helper_functions.jl
index 7afca36..281d648 100644
--- a/src/parts/helper_functions.jl
+++ b/src/parts/helper_functions.jl
@@ -8,7 +8,7 @@ import ..EPIReadouts: EPIReadout
 import ...BuildSequences: global_model, build_sequence
 import ...Containers: Sequence
 import ...Components: SincPulse, ConstantPulse, InstantPulse, SingleReadout, InstantGradient
-import ...Variables: qvec, flat_time, rise_time, qval, apply_simple_constraint!
+import ...Variables: qvec, flat_time, rise_time, qval, apply_simple_constraint!, variables
 
 
 function _get_pulse(shape, flip_angle, phase, frequency, Nzeros, group, bandwidth, duration)
@@ -141,11 +141,11 @@ Adds a readout event to the sequence.
 - `optimise`: Whether to optimise this readout event in isolation from the rest of the sequence. Use this with caution. It can speed up the optimisation (and for very complicated sequences make it more robust), however the resulting parameters might not represent the optimal solution of any external constraints (which are ignored if the readout is optimised in isolation).
 - `scanner`: Used for testing. Do not set this parameter at this level (instead set it for the total sequence using [`build_sequence`](@ref)).
 """
-function readout_event(; type=:epi, optimise=false, scanner=nothing, variables...)
+function readout_event(; type=:epi, optimise=false, scanner=nothing, all_variables...)
     if type == :instant
         optimise = false # there is nothing to optimise
     end
-    real_variables = Dict(key => value for (key, value) in pairs(variables) if !(isnothing(value) || (value isa AbstractVector && all(isnothing.(value)))))
+    real_variables = Dict(key => value for (key, value) in pairs(all_variables) if !(isnothing(value) || (value isa AbstractVector && all(isnothing.(value)))))
     build_sequence(scanner; optimise=optimise) do 
         func_dict = Dict(
             :epi => EPIReadout,
@@ -171,8 +171,8 @@ Returns two DWI gradients that are guaranteed to cancel each other out.
 - `optimise`: Whether to optimise this readout event in isolation from the rest of the sequence. Use this with caution. It can speed up the optimisation (and for very complicated sequences make it more robust), however the resulting parameters might not represent the optimal solution of any external constraints (which are ignored if the readout is optimised in isolation).
 - `scanner`: Used for testing. Do not set this parameter at this level (instead set it for the total sequence using [`build_sequence`](@ref)).
 """
-function dwi_gradients(; type=:trapezoid, optimise=false, scanner=nothing, refocus=true, orientation=[1, 0, 0], group=:DWI, variables...)
-    real_variables = Dict(key => value for (key, value) in pairs(variables) if !(isnothing(value) || (value isa AbstractVector && all(isnothing.(value)))))
+function dwi_gradients(; type=:trapezoid, optimise=false, scanner=nothing, refocus=true, orientation=[1, 0, 0], group=:DWI, match=nothing, all_variables...)
+    real_variables = Dict(key => value for (key, value) in pairs(all_variables) if !(isnothing(value) || (value isa AbstractVector && all(isnothing.(value)))))
 
     func_dict = Dict(
         :trapezoid => Trapezoid,
@@ -181,6 +181,12 @@ function dwi_gradients(; type=:trapezoid, optimise=false, scanner=nothing, refoc
     if !(type in keys(func_dict))
         error("DWI gradients type `$type` has not been implemented. Please use one of $(keys(func_dict)).")
     end
+
+    if isnothing(match)
+        match = get(Dict(
+            :trapezoid => [:rise_time, :flat_time, :slew_rate],
+        ), type, [])
+    end
     build_sequence(scanner; optimise=optimise) do 
         other_orientation = isnothing(orientation) ? nothing : (refocus ? orientation : -orientation)
         (g1, g2) = (
@@ -192,6 +198,12 @@ function dwi_gradients(; type=:trapezoid, optimise=false, scanner=nothing, refoc
         else
             apply_simple_constraint!(qval(g1), -qval(g2))
         end
+        for var_func in match
+            if var_func isa Symbol
+                var_func = variables[var_func]
+            end
+            apply_simple_constraint!(var_func(g1), var_func(g2))
+        end
         return (g1, g2)
     end
 end
-- 
GitLab