From 6fb836020356d8ac831a163201e3b52b8ad8e72d Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Fri, 2 Feb 2024 10:27:28 +0000
Subject: [PATCH] Replace model parameter with global_model

---
 src/MRIBuilder.jl                       |  8 ++--
 src/build_sequences.jl                  | 29 +++++++--------
 src/building_blocks.jl                  | 49 ++++++-------------------
 src/gradients/instant_gradients.jl      | 11 ++----
 src/helper_functions.jl                 |  6 +--
 src/overlapping/spoilt_slice_selects.jl |  8 ++--
 src/overlapping/trapezoid_gradients.jl  | 32 +++++++---------
 src/pulses/constant_pulses.jl           | 16 +++-----
 src/pulses/instant_pulses.jl            | 16 +++-----
 src/pulses/sinc_pulses.jl               | 19 ++++------
 src/scanners.jl                         | 10 ++---
 src/sequences.jl                        | 16 +++-----
 src/variables.jl                        | 31 +++++++---------
 13 files changed, 95 insertions(+), 156 deletions(-)

diff --git a/src/MRIBuilder.jl b/src/MRIBuilder.jl
index c449a35..11be175 100644
--- a/src/MRIBuilder.jl
+++ b/src/MRIBuilder.jl
@@ -16,8 +16,8 @@ include("sequences.jl")
 include("pathways.jl")
 include("helper_functions.jl")
 
-import .BuildSequences: build_sequence
-export build_sequence
+import .BuildSequences: build_sequence, global_model, global_scanner
+export build_sequence, global_model, global_scanner
 
 import .Scanners: Scanner, B0, Siemens_Connectom, Siemens_Prisma, Siemens_Terra
 export Scanner, B0, Siemens_Connectom, Siemens_Prisma, Siemens_Terra
@@ -52,7 +52,7 @@ export Pathway, duration_transverse, duration_dephase, bval, bmat
 import .HelperFunctions: excitation_pulse, refocus_pulse
 export excitation_pulse, refocus_pulse
 
-import JuMP: @constraint, @objective, objective_function, optimize!, has_values, value, owner_model, Model
-export @constraint, @objective, objective_function, optimize!, has_values, value, owner_model, Model
+import JuMP: @constraint, @objective, objective_function, value, Model
+export @constraint, @objective, objective_function, value, Model
 
 end
diff --git a/src/build_sequences.jl b/src/build_sequences.jl
index 70dfdd3..41deba6 100644
--- a/src/build_sequences.jl
+++ b/src/build_sequences.jl
@@ -3,7 +3,7 @@ import JuMP: Model, optimizer_with_attributes, optimize!
 import Ipopt
 import Juniper
 import ..Scanners: Scanner, scanner_constraints!
-import ..BuildingBlocks: BuildingBlock
+import ..BuildingBlocks: BuildingBlock, apply_simple_constraint!, match_blocks!
 
 const GLOBAL_MODEL = Ref(Model())
 const IGNORE_MODEL = GLOBAL_MODEL[]
@@ -60,29 +60,26 @@ function build_sequence(f::Function, scanner::Scanner)
 end
 
 
-function get_global_model()
+function global_model()
     if GLOBAL_MODEL[] == IGNORE_MODEL
         error("No global model has been set. Please explicitly set one in the constructor or set a global model using `set_model`.")
     end
     return GLOBAL_MODEL[]
 end
 
-scanner_constraints!(bb::BuildingBlock)  = scanner_constraints!(bb, GLOBAL_SCANNER[])
-
-
-"""
-    @global_model_constructor BuildingBlockType
-
-Add a constructor to the [`BuildingBlock`](@ref) subtype that fetches the global JuMP model (set by [`set_model`](@ref)) and assigns it to the first argument.
-```
-BuildingBlockType(args...; kwargs...) = BuildingBlockType(global_model::JuMP.Model, args...; kwargs...)
-```
-"""
-macro global_model_constructor(bb)
-    quote
-        $(esc(bb))(args...; kwargs...) = $(esc(bb))(get_global_model(), args...; kwargs...)
+function global_scanner()
+    if !isfinite(GLOBAL_SCANNER[].gradient)
+        error("No valid scanner has been set. Please provide one when calling `build_sequence`.")
     end
+    return GLOBAL_SCANNER[]
 end
 
 
+scanner_constraints!(bb::BuildingBlock)  = scanner_constraints!(bb, global_scanner())
+
+apply_simple_constraint!(variable, value) = apply_simple_constraint!(global_model(), variable, value)
+match_blocks!(block1::BuildingBlock, block2::BuildingBlock, property_list) = match_blocks!(global_model(), block1, block2, property_list)
+scanner_constraints!(building_block::BuildingBlock, scanner::Scanner, func::Function) = scanner_constraints!(building_block, scanner, func)
+scanner_constraints!(building_block::BuildingBlock) = scanner_constraints!(building_block, global_scanner())
+
 end
\ No newline at end of file
diff --git a/src/building_blocks.jl b/src/building_blocks.jl
index 48588f5..3c278dd 100644
--- a/src/building_blocks.jl
+++ b/src/building_blocks.jl
@@ -1,5 +1,5 @@
 module BuildingBlocks
-import JuMP: has_values, value, Model, @constraint, @objective, owner_model, objective_function, optimize!, AbstractJuMPScalar
+import JuMP: value, Model, @constraint, @objective, objective_function, AbstractJuMPScalar
 import Printf: @sprintf
 import ..Variables: variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, qval_square
 
@@ -167,7 +167,7 @@ function Base.show(io::IO, printer::BuildingBlockPrinter)
     for name in propertynames(block)
         ft = fieldtype(typeof(block), name)
         if (
-            ft in (VariableType, Model) ||
+            ft == VariableType ||
             (ft <: AbstractVector && eltype(ft) == VariableType) ||
             string(name)[1] == '_'
         )
@@ -197,15 +197,15 @@ end
 
 
 """
-    set_simple_constraints!(model, block, kwargs)
+    set_simple_constraints!(block, kwargs)
 
-Add any constraints or objective functions to the variables of a [`BuildingBlock`](@ref) in the JuMP `model`.
+Add any constraints or objective functions to the variables of a [`BuildingBlock`](@ref).
 
 Each keyword argument has to match one of the functions in [`variables`](@ref)(block).
 If set to a numeric value, a constraint will be added to fix the function value to that numeric value.
 If set to `:min` or `:max`, minimising or maximising this function will be added to the cost function.
 """
-function set_simple_constraints!(model::Model, block::BuildingBlock, kwargs)
+function set_simple_constraints!(block::BuildingBlock, kwargs)
     to_funcs = Dict(nameof(fn) => fn for fn in variables(block))
 
     invert_value(value::VariableType) = 1 / value
@@ -217,14 +217,14 @@ function set_simple_constraints!(model::Model, block::BuildingBlock, kwargs)
 
     for (key, value) in kwargs
         if key in keys(to_funcs)
-            apply_simple_constraint!(model, to_funcs[key](block), value)
+            apply_simple_constraint!(to_funcs[key](block), value)
         else
             if key == :qval
-                apply_simple_constraint!(model, to_funcs[:qval_square](block), value isa VariableType ? value^2 : value)
+                apply_simple_constraint!(to_funcs[:qval_square](block), value isa VariableType ? value^2 : value)
             elseif key == :slice_thickness && :inverse_slice_thickness in keys(to_funcs)
-                apply_simple_constraint!(model, to_funcs[:inverse_slice_thickness](block), invert_value(value))
+                apply_simple_constraint!(to_funcs[:inverse_slice_thickness](block), invert_value(value))
             elseif key == :bandwidth && :inverse_bandwidth in keys(to_funcs)
-                apply_simple_constraint!(model, to_funcs[:inverse_bandwidth](block), invert_value(value))
+                apply_simple_constraint!(to_funcs[:inverse_bandwidth](block), invert_value(value))
             else
                 error("Trying to set an unrecognised variable $key.")
             end
@@ -234,7 +234,7 @@ function set_simple_constraints!(model::Model, block::BuildingBlock, kwargs)
 end
 
 """
-    apply_simple_constraint!(model, variable, value)
+    apply_simple_constraint!([model, ]variable, value)
 
 Add a single constraint or objective to the JuMP `model`.
 This is an internal function used by [`set_simple_constraints`](@ref).
@@ -253,9 +253,7 @@ apply_simple_constraint!(model::Model, variable::AbstractVector, value::Abstract
 Matches the listed variables between two [`BuildingBlock`](@ref) objects.
 By default all shared variables (i.e., those with the same name) are matched.
 """
-function match_blocks!(block1::BuildingBlock, block2::BuildingBlock, property_list)
-    model = owner_model(block1)
-    @assert model == owner_model(block2)
+function match_blocks!(model::Model, block1::BuildingBlock, block2::BuildingBlock, property_list)
     for fn in property_list
         @constraint model fn(block1) == fn(block2)
     end
@@ -266,29 +264,4 @@ function match_blocks!(block1::BuildingBlock, block2::BuildingBlock)
     match_blocks!(block1, block2, property_list)
 end
 
-
-optimize!(bb::BuildingBlock) = optimize!(owner_model(bb))
-function owner_model(bb::BuildingBlock)
-    if hasproperty(bb, :model)
-        return bb.model
-    else
-        for name in propertynames(bb)
-            value = getproperty(bb, name)
-            if value isa AbstractJuMPScalar
-                return owner_model(value)
-            end
-        end
-    end
-    error("Cannot find owner model")
-end
-
-function has_values(bb::BuildingBlock) 
-    try 
-        return has_values(owner_model(bb))
-    catch
-        # return true for building blocks without a model
-        return true
-    end
-end
-
 end
\ No newline at end of file
diff --git a/src/gradients/instant_gradients.jl b/src/gradients/instant_gradients.jl
index 1586625..8658b38 100644
--- a/src/gradients/instant_gradients.jl
+++ b/src/gradients/instant_gradients.jl
@@ -1,9 +1,9 @@
 module InstantGradients
 import StaticArrays: SVector
-import JuMP: @constraint, @variable, Model, owner_model, AbstractJuMPScalar
+import JuMP: @constraint, @variable, AbstractJuMPScalar
 import ...Variables: qvec, bmat_gradient, duration, variables, get_free_variable, VariableType
 import ...BuildingBlocks: GradientBlock, fixed
-import ...BuildSequences: @global_model_constructor
+import ...BuildSequences: global_model
 
 """
     InstantGradientBlock(; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing)
@@ -20,14 +20,11 @@ Defines an instantaneous gradient.
 - [`qval`](@ref): Spatial scale on which spins will be dephased due to this pulsed gradient in rad/um.
 """
 struct InstantGradientBlock <: GradientBlock
-    model::Model
     qvec :: SVector{3, VariableType}
     rotate :: Union{Nothing, Symbol}
     scale :: Union{Nothing, Symbol}
 end
 
-@global_model_constructor InstantGradientBlock
-
 function interpret_orientation(orientation, qval, qvec, will_rotate)
     if !isnothing(qvec)
         return (false, qvec)
@@ -60,10 +57,10 @@ function interpret_orientation(orientation, qval, qvec, will_rotate)
     end
 end
 
-function InstantGradientBlock(model::Model; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing)
+function InstantGradientBlock(; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing)
+    model = global_model()
     (used_qval, qvec) = interpret_orientation(orientation, qval, qvec, !isnothing(rotate))
     res = InstantGradientBlock(
-        model,
         qvec,
         rotate,
         scale
diff --git a/src/helper_functions.jl b/src/helper_functions.jl
index a170638..9be264f 100644
--- a/src/helper_functions.jl
+++ b/src/helper_functions.jl
@@ -1,6 +1,6 @@
 module HelperFunctions
 import JuMP: @constraint
-import ..BuildSequences: get_global_model
+import ..BuildSequences: global_model
 import ..Sequences: Sequence
 import ..Pulses: SincPulse, ConstantPulse, InstantRFPulseBlock
 import ..Overlapping: TrapezoidGradient, SpoiltSliceSelect
@@ -65,7 +65,7 @@ function excitation_pulse(; flip_angle=90, phase=0., frequency=0., shape=:sinc,
         TrapezoidGradient(orientation=[0, 0, -1.], rotate=rotate_grad, duration=:min);
         TR=Inf
     )
-    @constraint get_global_model() qvec(grad, 1, nothing)[3] == -qvec(seq[2])[3]
+    @constraint global_model() qvec(grad, 1, nothing)[3] == -qvec(seq[2])[3]
     return seq
 end
 
@@ -112,7 +112,7 @@ function refocus_pulse(; flip_angle=180, phase=0., frequency=0., shape=:sinc, sl
         orientation=isnothing(rotate_grad) ? [1, 1, 1] : [0, 0, 1]
         if isinf(slice_thickness)
             grad = TrapezoidGradient(orientation=orientation, duration=:min, rotate=rotate_grad)
-            @constraint get_global_model() qvec(grad)[3] == 2Ï€ * 1e-3 / spoiler
+            @constraint global_model() qvec(grad)[3] == 2Ï€ * 1e-3 / spoiler
             return Sequence(grad, pulse, grad; TR=Inf)
         else
             res = SpoiltSliceSelect(pulse; orientation=orientation, duration=:min, rotate=rotate_grad, slice_thickness=slice_thickness, spoiler_scale=spoiler)
diff --git a/src/overlapping/spoilt_slice_selects.jl b/src/overlapping/spoilt_slice_selects.jl
index 1cbe4a5..e218625 100644
--- a/src/overlapping/spoilt_slice_selects.jl
+++ b/src/overlapping/spoilt_slice_selects.jl
@@ -4,7 +4,7 @@ import LinearAlgebra: norm
 import StaticArrays: SVector
 import JuMP: @variable, @constraint, @objective, objective_function
 import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!
-import ...BuildSequences: GLOBAL_SCANNER, get_global_model
+import ...BuildSequences: global_model, global_scanner
 import ...Variables: VariableType, variables, duration, rise_time, flat_time, effective_time, qvec, gradient_strength, slew_rate
 import ...Gradients: ChangingGradientBlock, ConstantGradientBlock
 import ..Abstract: interruptions, waveform, AbstractOverlapping
@@ -37,7 +37,7 @@ struct SpoiltSliceSelect <: AbstractOverlapping
 end
 
 function SpoiltSliceSelect(pulse; orientation=[0, 0, 1], rotate=nothing, spoiler_scale=nothing, kwargs...)
-    model = get_global_model()
+    model = global_model()
     res = SpoiltSliceSelect(
         orientation ./ norm(orientation),
         @variable(model, start=0.1),
@@ -47,7 +47,7 @@ function SpoiltSliceSelect(pulse; orientation=[0, 0, 1], rotate=nothing, spoiler
         @variable(model, start=0.1),
         @variable(model, start=0.1),
         rotate,
-        slew_rate(GLOBAL_SCANNER[])
+        slew_rate(global_scanner())
     )
     for time_var in (res.rise_time1, res.flat_time1, res.diff_time, res.flat_time2, res.fall_time2)
         @constraint model time_var >= 0
@@ -68,7 +68,7 @@ function SpoiltSliceSelect(pulse; orientation=[0, 0, 1], rotate=nothing, spoiler
     end
     set_simple_constraints!(model, res, kwargs)
 
-    max_time = gradient_strength(GLOBAL_SCANNER[]) / slew_rate(GLOBAL_SCANNER[])
+    max_time = gradient_strength(global_scanner()) / res.slew_rate
     @constraint model rise_time(res)[1] <= max_time
     @constraint model fall_time(res)[2] <= max_time
     return res
diff --git a/src/overlapping/trapezoid_gradients.jl b/src/overlapping/trapezoid_gradients.jl
index b199e49..5f86100 100644
--- a/src/overlapping/trapezoid_gradients.jl
+++ b/src/overlapping/trapezoid_gradients.jl
@@ -3,12 +3,12 @@ Defines a set of different options for MRI gradients.
 """
 module TrapezoidGradients
 
-import JuMP: @constraint, @variable, Model, VariableRef, owner_model, value
+import JuMP: @constraint, @variable, VariableRef, value
 import StaticArrays: SVector
 import LinearAlgebra: norm
 import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_slice_thickness, inverse_bandwidth, effective_time
 import ...BuildingBlocks: duration, set_simple_constraints!, fixed, RFPulseBlock
-import ...BuildSequences: @global_model_constructor
+import ...BuildSequences: global_model
 import ...Gradients: ChangingGradientBlock, ConstantGradientBlock
 import ...Scanners: scanner_constraints!
 import ..Abstract: interruptions, waveform, AbstractOverlapping
@@ -43,7 +43,6 @@ Any variables defined for the specific pulse added. Also:
 The [`bvalue`](@ref) can be constrained for multiple gradient pulses.
 """
 struct TrapezoidGradient <: AbstractOverlapping
-    model :: Model
     rise_time :: VariableType
     flat_time :: VariableType
     slew_rate :: SVector{3, VariableType}
@@ -54,21 +53,19 @@ struct TrapezoidGradient <: AbstractOverlapping
     time_after_pulse :: VariableType
 end
 
-@global_model_constructor TrapezoidGradient
-
-function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing, flat_time=nothing, rotate=nothing, scale=nothing, pulse=nothing,kwargs...)
+function TrapezoidGradient(orientation=nothing, rise_time=nothing, flat_time=nothing, rotate=nothing, scale=nothing, pulse=nothing,kwargs...)
     if isnothing(orientation) && isnothing(rotate)
         rate_1d = nothing
         slew_rate = (
-            get_free_variable(model, nothing),
-            get_free_variable(model, nothing),
-            get_free_variable(model, nothing),
+            get_free_variable(nothing),
+            get_free_variable(nothing),
+            get_free_variable(nothing),
         )
     else
-        rate_1d = get_free_variable(model, nothing)
-        @constraint model rate_1d >= 0
+        rate_1d = get_free_variable(nothing)
+        @constraint global_model() rate_1d >= 0
         if isnothing(orientation)
-            rate_1d = get_free_variable(model, nothing)
+            rate_1d = get_free_variable(nothing)
             slew_rate = (
                 rate_1d,
                 0.,
@@ -86,9 +83,9 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing,
         end
     end
     slew_rate = SVector{3}(slew_rate)
-    rise_time = get_free_variable(model, rise_time)
+    rise_time = get_free_variable(rise_time)
     if isnothing(pulse)
-        flat_time = get_free_variable(model, flat_time)
+        flat_time = get_free_variable(flat_time)
         time_before_pulse = time_after_pulse = 0.
     elseif pulse isa RFPulseBlock
         flat_time = duration(pulse)
@@ -99,7 +96,6 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing,
     end
 
     res = TrapezoidGradient(
-        model,
         rise_time,
         flat_time,
         slew_rate,
@@ -110,9 +106,9 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing,
         time_after_pulse
     )
 
-    set_simple_constraints!(model, res, kwargs)
-    @constraint model flat_time >= 0
-    @constraint model rise_time >= 0
+    set_simple_constraints!(res, kwargs)
+    @constraint flat_time >= 0
+    @constraint rise_time >= 0
     scanner_constraints!(res)
     return res
 end
diff --git a/src/pulses/constant_pulses.jl b/src/pulses/constant_pulses.jl
index 23e73b9..cf069cb 100644
--- a/src/pulses/constant_pulses.jl
+++ b/src/pulses/constant_pulses.jl
@@ -1,8 +1,8 @@
 module ConstantPulses
-import JuMP: VariableRef, @constraint, @variable, value, Model
+import JuMP: VariableRef, @constraint, @variable, value
 import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed
 import ...Variables: variables, get_free_variable, flip_angle, phase, amplitude, frequency, start_time, end_time, VariableType, duration, effective_time, inverse_bandwidth
-import ...BuildSequences: @global_model_constructor
+import ...BuildSequences: global_model
 import ..FixedPulses: FixedPulse
 
 """
@@ -18,7 +18,6 @@ Represents an radio-frequency pulse with a constant amplitude and frequency (i.e
 - [`frequency`](@ref): frequency of the RF pulse relative to the Larmor frequency (in kHz).
 """
 struct ConstantPulse <: RFPulseBlock
-    model :: Model
     amplitude :: VariableType
     duration :: VariableType
     phase :: VariableType
@@ -26,17 +25,14 @@ struct ConstantPulse <: RFPulseBlock
     scale :: Union{Nothing, Symbol}
 end
 
-@global_model_constructor ConstantPulse
-
-function ConstantPulse(model::Model; amplitude=nothing, duration=nothing, phase=nothing, frequency=nothing, scale=nothing, kwargs...) 
+function ConstantPulse(amplitude=nothing, duration=nothing, phase=nothing, frequency=nothing, scale=nothing, kwargs...) 
     res = ConstantPulse(
-        model,
-        [get_free_variable(model, value) for value in (amplitude, duration, phase, frequency)]...,
+        [get_free_variable(value) for value in (amplitude, duration, phase, frequency)]...,
         scale
 
     )
-    @constraint model res.amplitude >= 0
-    set_simple_constraints!(model, res, kwargs)
+    @constraint global_model() res.amplitude >= 0
+    set_simple_constraints!(res, kwargs)
     return res
 end
 
diff --git a/src/pulses/instant_pulses.jl b/src/pulses/instant_pulses.jl
index fb49d0e..7234ce0 100644
--- a/src/pulses/instant_pulses.jl
+++ b/src/pulses/instant_pulses.jl
@@ -1,27 +1,23 @@
 module InstantPulses
-import JuMP: @constraint, @variable, VariableRef, value, Model
+import JuMP: @constraint, @variable, VariableRef, value
 import ...BuildingBlocks: RFPulseBlock, fixed
 import ...Variables: flip_angle, phase, start_time, variables, duration, get_free_variable, VariableType, effective_time, inverse_bandwidth
-import ...BuildSequences: @global_model_constructor
+import ...BuildSequences: global_model
 import ..FixedPulses: FixedInstantPulse
 
 struct InstantRFPulseBlock <: RFPulseBlock
-    model :: Model
     flip_angle :: VariableType
     phase :: VariableType
     scale :: Union{Nothing, Symbol}
 end
 
-@global_model_constructor InstantRFPulseBlock
-
-function InstantRFPulseBlock(model::Model; flip_angle=nothing, phase=nothing, scale=nothing) 
+function InstantRFPulseBlock(flip_angle=nothing, phase=nothing, scale=nothing) 
     res = InstantRFPulseBlock(
-        model,
-        get_free_variable(model, flip_angle),
-        get_free_variable(model, phase),
+        get_free_variable(flip_angle),
+        get_free_variable(phase),
         scale
     )
-    @constraint model res.flip_angle >= 0
+    @constraint global_model() res.flip_angle >= 0
     return res
 end
 
diff --git a/src/pulses/sinc_pulses.jl b/src/pulses/sinc_pulses.jl
index ecdcfe8..35db4b2 100644
--- a/src/pulses/sinc_pulses.jl
+++ b/src/pulses/sinc_pulses.jl
@@ -1,11 +1,11 @@
 module SincPulses
 
-import JuMP: VariableRef, @constraint, @variable, value, Model
+import JuMP: VariableRef, @constraint, @variable, value
 import QuadGK: quadgk
 import Polynomials: fit, Polynomial
 import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed
 import ...Variables: flip_angle, phase, amplitude, frequency, VariableType, variables, get_free_variable, duration, effective_time, inverse_bandwidth
-import ...BuildSequences: @global_model_constructor
+import ...BuildSequences: global_model
 import ..FixedPulses: FixedPulse
 
 """
@@ -29,7 +29,6 @@ Represents an radio-frequency pulse with a constant amplitude and frequency.
 - [`bandwidth`](@ref): width of the rectangular function in frequency space (in kHz). If the `duration` is short (compared with 1/`bandwidth`), this bandwidth will only be approximate.
 """
 struct SincPulse <: RFPulseBlock
-    model :: Model
     symmetric :: Bool
     apodise :: Bool
     nlobe_integral :: Polynomial
@@ -42,31 +41,29 @@ struct SincPulse <: RFPulseBlock
     scale :: Union{Nothing, Symbol}
 end
 
-@global_model_constructor SincPulse
-
-function SincPulse(model::Model; 
+function SincPulse(; 
     symmetric=true, max_Nlobes=nothing, apodise=true, N_lobes=nothing, N_left=nothing, N_right=nothing, 
     amplitude=nothing, phase=nothing, frequency=nothing, lobe_duration=nothing, scale=nothing, kwargs...
 ) 
     if symmetric
-        N_lobes = get_free_variable(model, N_lobes)
+        N_lobes = get_free_variable(N_lobes)
         @assert isnothing(N_left) && isnothing(N_right) "N_left and N_right cannot be set if symmetric=true (default)"
         N_left_var = N_right_var = N_lobes
     else
         @assert isnothing(N_lobes) "N_lobes cannot be set if symmetric=true (default)"
-        N_left_var = get_free_variable(model, N_left)
-        N_right_var = get_free_variable(model, N_right)
+        N_left_var = get_free_variable(N_left)
+        N_right_var = get_free_variable(N_right)
     end
     res = SincPulse(
-        model,
         symmetric,
         apodise,
         nlobe_integral_params(max_Nlobes, apodise),
         N_left_var,
         N_right_var,
-        [get_free_variable(model, value) for value in (amplitude, phase, frequency, lobe_duration)]...,
+        [get_free_variable(value) for value in (amplitude, phase, frequency, lobe_duration)]...,
         scale
     )
+    model = global_model()
     @constraint model res.amplitude >= 0
     @constraint model res.N_left >= 1
     if !symmetric
diff --git a/src/scanners.jl b/src/scanners.jl
index b5ddf9a..6c026a8 100644
--- a/src/scanners.jl
+++ b/src/scanners.jl
@@ -2,7 +2,7 @@
 Define general [`Scanner`](@ref) type and methods as well as some concrete scanners.
 """
 module Scanners
-import JuMP: Model, @constraint, owner_model
+import JuMP: Model, @constraint
 import ..Variables: gradient_strength, slew_rate
 import ..BuildingBlocks: BuildingBlock, get_children_blocks, ContainerBlock
 import ..Variables: variables
@@ -89,18 +89,14 @@ predefined_scanners = Dict(
 )
 
 """
-    scanner_constraints!([model, ]building_block[, scanner])
+    scanner_constraints!(building_block[, scanner])
 
 Adds any constraints from a specific scanner to a [`BuildingBlock`]{@ref}.
 """
 function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner)
-    scanner_constraints!(owner_model(building_block), building_block, scanner)
-end
-
-function scanner_constraints!(model::Model, building_block::BuildingBlock, scanner::Scanner)
     for func in [gradient_strength, slew_rate]
         if isfinite(func(scanner))
-            scanner_constraints!(model, building_block, scanner, func)
+            scanner_constraints!(building_block, scanner, func)
         end
     end
 end
diff --git a/src/sequences.jl b/src/sequences.jl
index e6f2947..bd5910c 100644
--- a/src/sequences.jl
+++ b/src/sequences.jl
@@ -3,8 +3,8 @@ Define the [`Sequence`](@ref) building block.
 """
 module Sequences
 import Printf: @sprintf
-import JuMP: Model, @constraint
-import ...BuildSequences: @global_model_constructor
+import JuMP: @constraint
+import ...BuildSequences: global_model
 import ...Variables: variables, start_time, duration, VariableType, get_free_variable, TR, end_time
 import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks
 
@@ -21,25 +21,21 @@ or be embedded as a [`BuildingBlock`](@ref) into higher-order `Sequence` or othe
 - [`TR`](@ref): repetition time of sequence in ms.
 """
 struct Sequence <: ContainerBlock
-    model :: Model
     _blocks :: Vector{<:BuildingBlock}
     TR :: VariableType
-    function Sequence(model::Model, blocks::AbstractVector; TR=nothing)
+    function Sequence(blocks::AbstractVector; TR=nothing)
         seq = new(
-            model,
             to_block.(blocks),
-            get_free_variable(model, TR),
+            get_free_variable(TR),
         )
         if !(TR isa Number && isinf(TR))
-            @constraint model seq.TR >= duration(seq)
+            @constraint global_model() seq.TR >= duration(seq)
         end
         return seq
     end
 end
 
-@global_model_constructor Sequence
-
-Sequence(model::Model, blocks...; TR=nothing) = Sequence(model, [blocks...]; TR=TR)
+Sequence(blocks...; TR=nothing) = Sequence([blocks...]; TR=TR)
 
 Base.length(seq::Sequence) = length(seq._blocks)
 Base.getindex(seq::Sequence, index) = seq._blocks[index]
diff --git a/src/variables.jl b/src/variables.jl
index 01afc14..d1b7f0d 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -1,5 +1,6 @@
 module Variables
-import JuMP: @variable, Model, @objective, objective_function, owner_model, has_values, value, AbstractJuMPScalar
+import JuMP: @variable, Model, @objective, objective_function, value, AbstractJuMPScalar
+import ..
 
 all_variables_symbols = [
     :block => [
@@ -82,29 +83,23 @@ const VariableType = Union{Number, AbstractJuMPScalar}
 
 
 """
-    get_free_variable(model, value; integer=false)
+    get_free_variable(value; integer=false)
 
 Get a representation of a given `variable` given a user-defined constraint.
 """
-get_free_variable(::Model, value::Number; integer=false) = integer ? Int(value) : Float64(value)
-function get_free_variable(model::Model, value::VariableType; integer=false)
-    if owner_model(value) != model
-        if has_values(value)
-            return value(value)
-        end
-        error("Cannot set any constraints between sequences stored in different JuMP models.")
-    end
-    return value
-end
-get_free_variable(model::Model, ::Nothing; integer=false) = @variable(model, start=0.01, integer=integer)
-get_free_variable(model::Model, value::Symbol; integer=false) = integer ? error("Cannot maximise or minimise an integer variable") : get_free_variable(model, Val(value))
-function get_free_variable(model::Model, ::Val{:min})
-    var = get_free_variable(model, nothing)
+get_free_variable(value::Number; integer=false) = integer ? Int(value) : Float64(value)
+get_free_variable(value::VariableType; integer=false) = value
+get_free_variable(::Nothing; integer=false) = @variable(global_model(), start=0.01, integer=integer)
+get_free_variable(value::Symbol; integer=false) = integer ? error("Cannot maximise or minimise an integer variable") : get_free_variable(Val(value))
+function get_free_variable(::Val{:min})
+    var = get_free_variable(nothing)
+    model = global_model()
     @objective model Min objective_function(model) + var
     return var
 end
-function get_free_variable(model::Model, ::Val{:max})
-    var = get_free_variable(model, nothing)
+function get_free_variable(::Val{:max})
+    var = get_free_variable(nothing)
+    model = global_model()
     @objective model Min objective_function(model) - var
     return var
 end
-- 
GitLab