From bd39a767a7f1001e67e01214ecbe5af4284bd3e4 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Thu, 1 Feb 2024 14:55:22 +0000
Subject: [PATCH] Set scanner at a global level

---
 src/build_sequences.jl                 | 17 ++++++++++++++---
 src/overlapping/generic.jl             |  6 ++++++
 src/overlapping/trapezoid_gradients.jl |  2 ++
 src/scanners.jl                        |  5 +++--
 src/sequences.jl                       | 15 +++++----------
 5 files changed, 30 insertions(+), 15 deletions(-)

diff --git a/src/build_sequences.jl b/src/build_sequences.jl
index df993dc..5182ce9 100644
--- a/src/build_sequences.jl
+++ b/src/build_sequences.jl
@@ -2,10 +2,14 @@ module BuildSequences
 import JuMP: Model, optimizer_with_attributes, optimize!
 import Ipopt
 import Juniper
+import ..Scanners: Scanner, scanner_constraints!
+import ..BuildingBlocks: BuildingBlock
 
 const GLOBAL_MODEL = Ref(Model())
 const IGNORE_MODEL = GLOBAL_MODEL[]
 
+const GLOBAL_SCANNER = Ref(Scanner())
+
 """
 Wrapper to build a sequence.
 
@@ -31,23 +35,28 @@ You can also add any arbitrary constraints or objectives using the same syntax a
 
 As soon as the code block is the optimal sequence matching all your constraints and objectives will be returned.
 """
-function build_sequence(f::Function, model::Model)
+function build_sequence(f::Function, scanner::Scanner, model::Model)
     prev_model = GLOBAL_MODEL[]
     GLOBAL_MODEL[] = model
+    prev_scanner = GLOBAL_SCANNER[]
+    if !isnothing(scanner)
+        GLOBAL_SCANNER[] = scanner
+    end
     try
         sequence = f(model)
         optimize!(model)
         return sequence
     finally
         GLOBAL_MODEL[] = prev_model
+        GLOBAL_SCANNER[] = prev_scanner
     end
 end
 
-function build_sequence(f::Function)
+function build_sequence(f::Function, scanner::Scanner)
     ipopt_opt = optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 3)
     juniper_opt = optimizer_with_attributes(Juniper.Optimizer, "nl_solver" => ipopt_opt)
     model = Model(juniper_opt)
-    build_sequence(f, model)
+    build_sequence(f, scanner, model)
 end
 
 
@@ -58,6 +67,8 @@ function get_global_model()
     return GLOBAL_MODEL[]
 end
 
+scanner_constraints!(bb::BuildingBlock)  = scanner_constraints!(bb, GLOBAL_SCANNER[])
+
 
 """
     @global_model_constructor BuildingBlockType
diff --git a/src/overlapping/generic.jl b/src/overlapping/generic.jl
index 2874db5..433437f 100644
--- a/src/overlapping/generic.jl
+++ b/src/overlapping/generic.jl
@@ -5,6 +5,7 @@ import ...Wait: WaitBlock
 import ...Readouts: InstantReadout
 import ...Pulses: RFPulseBlock
 import ...Gradients: GradientBlock
+import ...Scanners: scanner_constraints!
 
 """
     GenericWaveform(duration, waveform, interruptions)
@@ -20,6 +21,11 @@ struct GenericOverlapping <: AbstractOverlapping
     duration :: VariableType
     waveform :: Vector{Union{WaitBlock, GradientBlock}}
     interruptions :: Vector{NamedTuple{(:index, :time, :block), Tuple{Int64, <:VariableType, <:Union{RFPulseBlock, InstantReadout}}}}
+    function GenericOverlapping(duration::VariableType, waveform::AbstractVector, interruptions::AbstractVector=[])
+        res = new(duration, waveform, interruptions)
+        scanner_constraints!.(waveform)
+        return res
+    end
 end
 
 
diff --git a/src/overlapping/trapezoid_gradients.jl b/src/overlapping/trapezoid_gradients.jl
index 39fb32e..f5c1f91 100644
--- a/src/overlapping/trapezoid_gradients.jl
+++ b/src/overlapping/trapezoid_gradients.jl
@@ -10,6 +10,7 @@ import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, v
 import ...BuildingBlocks: duration, set_simple_constraints!, fixed, RFPulseBlock
 import ...BuildSequences: @global_model_constructor
 import ...Gradients: ChangingGradientBlock, ConstantGradientBlock
+import ...Scanners: scanner_constraints!
 import ..Abstract: interruptions, waveform, AbstractOverlapping
 
 
@@ -112,6 +113,7 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing,
     set_simple_constraints!(model, res, kwargs)
     @constraint model flat_time >= 0
     @constraint model rise_time >= 0
+    scanner_constraints!(res)
     return res
 end
 
diff --git a/src/scanners.jl b/src/scanners.jl
index e707a0d..b5ddf9a 100644
--- a/src/scanners.jl
+++ b/src/scanners.jl
@@ -4,7 +4,8 @@ Define general [`Scanner`](@ref) type and methods as well as some concrete scann
 module Scanners
 import JuMP: Model, @constraint, owner_model
 import ..Variables: gradient_strength, slew_rate
-import ..BuildingBlocks: BuildingBlock, get_children_blocks
+import ..BuildingBlocks: BuildingBlock, get_children_blocks, ContainerBlock
+import ..Variables: variables
 
 const gyromagnetic_ratio = 42576.38476  # (kHz/T)
 
@@ -88,7 +89,7 @@ predefined_scanners = Dict(
 )
 
 """
-    scanner_constraints!([model, ]building_block, scanner)
+    scanner_constraints!([model, ]building_block[, scanner])
 
 Adds any constraints from a specific scanner to a [`BuildingBlock`]{@ref}.
 """
diff --git a/src/sequences.jl b/src/sequences.jl
index 37a7af1..e6f2947 100644
--- a/src/sequences.jl
+++ b/src/sequences.jl
@@ -6,16 +6,14 @@ import Printf: @sprintf
 import JuMP: Model, @constraint
 import ...BuildSequences: @global_model_constructor
 import ...Variables: variables, start_time, duration, VariableType, get_free_variable, TR, end_time
-import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, scanner_constraints!, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks
+import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks
 
 """
-    Sequence(building_blocks...; TR=nothing, scanner=nothing)
-    Sequence([building_blocks]; TR=nothing, scanner=nothing)
+    Sequence(building_blocks...; TR=nothing)
+    Sequence([building_blocks]; TR=nothing)
 
 Represents a series of [`BuildingBlock`](@ref) objects run in turn.
 
-Providing a [`Scanner`](@ref) will lead to [`scanner_constraints!`](@ref) to be called on all building blocks.
-
 This can be used as a top-level NMR/MRI sequence (in which case the [`TR`](@ref) variable is relevant)
 or be embedded as a [`BuildingBlock`](@ref) into higher-order `Sequence` or other [`ContainerBlock`](@ref) objects.
 
@@ -26,7 +24,7 @@ struct Sequence <: ContainerBlock
     model :: Model
     _blocks :: Vector{<:BuildingBlock}
     TR :: VariableType
-    function Sequence(model::Model, blocks::AbstractVector; TR=nothing, scanner=nothing)
+    function Sequence(model::Model, blocks::AbstractVector; TR=nothing)
         seq = new(
             model,
             to_block.(blocks),
@@ -35,16 +33,13 @@ struct Sequence <: ContainerBlock
         if !(TR isa Number && isinf(TR))
             @constraint model seq.TR >= duration(seq)
         end
-        if !isnothing(scanner)
-            scanner_constraints!(model, seq, scanner)
-        end
         return seq
     end
 end
 
 @global_model_constructor Sequence
 
-Sequence(model::Model, blocks...; TR=nothing, scanner=nothing) = Sequence(model, [blocks...]; TR=TR, scanner=scanner)
+Sequence(model::Model, blocks...; TR=nothing) = Sequence(model, [blocks...]; TR=TR)
 
 Base.length(seq::Sequence) = length(seq._blocks)
 Base.getindex(seq::Sequence, index) = seq._blocks[index]
-- 
GitLab