From 2bb13c0948cbf8742f0cae8214ab53efc32e4786 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@protonmail.com>
Date: Thu, 25 Jul 2024 17:27:12 +0100
Subject: [PATCH] ENH: use iterative cost functions

This has been implemented using a new add_cost_function! function
---
 src/MRIBuilder.jl                         |   7 +-
 src/build_sequences.jl                    | 105 ++++++++++++----------
 src/components/instant_gradients.jl       |   1 -
 src/components/pulses/composite_pulses.jl |   7 --
 src/components/pulses/constant_pulses.jl  |   1 -
 src/components/pulses/instant_pulses.jl   |   1 -
 src/components/pulses/sinc_pulses.jl      |   1 -
 src/components/readouts/ADCs.jl           |   2 +-
 src/containers/alternatives.jl            |  10 +--
 src/containers/base_sequences.jl          |   2 +-
 src/containers/building_blocks.jl         |   1 -
 src/parts/helper_functions.jl             |   2 +-
 src/parts/spoilt_slice_selects.jl         |  13 ++-
 src/parts/trapezoids.jl                   |  10 ++-
 src/variables.jl                          |  44 +++++----
 15 files changed, 106 insertions(+), 101 deletions(-)

diff --git a/src/MRIBuilder.jl b/src/MRIBuilder.jl
index 591f342..2c8c615 100644
--- a/src/MRIBuilder.jl
+++ b/src/MRIBuilder.jl
@@ -16,8 +16,8 @@ include("printing.jl")
 include("sequence_io/sequence_io.jl")
 include("plot.jl")
 
-import .BuildSequences: build_sequence, global_model, global_scanner, fixed
-export build_sequence, global_model, global_scanner, fixed
+import .BuildSequences: build_sequence, global_scanner, fixed
+export build_sequence, global_scanner, fixed
 
 import .Scanners: Scanner, B0, Siemens_Connectom, Siemens_Prisma, Siemens_Terra, Default_Scanner
 export Scanner, B0, Siemens_Connectom, Siemens_Prisma, Siemens_Terra, Default_Scanner
@@ -49,7 +49,4 @@ export read_sequence, write_sequence
 import .Plot: plot_sequence
 export plot_sequence
 
-import JuMP: @constraint, @objective, objective_function, value, Model
-export @constraint, @objective, objective_function, value, Model
-
 end
diff --git a/src/build_sequences.jl b/src/build_sequences.jl
index 0b78486..cc8ecab 100644
--- a/src/build_sequences.jl
+++ b/src/build_sequences.jl
@@ -1,14 +1,30 @@
 module BuildSequences
-import JuMP: Model, optimizer_with_attributes, optimize!, AbstractJuMPScalar, value, solution_summary, termination_status, LOCALLY_SOLVED, OPTIMAL, num_variables, all_variables, set_start_value, ALMOST_LOCALLY_SOLVED, objective_value, INVALID_MODEL, @variable
+import JuMP: Model, optimizer_with_attributes, optimize!, AbstractJuMPScalar, value, solution_summary, termination_status, LOCALLY_SOLVED, OPTIMAL, num_variables, all_variables, set_start_value, ALMOST_LOCALLY_SOLVED, objective_value, INVALID_MODEL, @variable, @objective, @constraint
 import Ipopt
 import Juniper
 import ..Scanners: Scanner, gradient_strength, Default_Scanner
 
-const GLOBAL_MODEL = Ref(Model())
+const GLOBAL_MODEL = Ref((Model(), Tuple{Float64, AbstractJuMPScalar}[]))
 const IGNORE_MODEL = GLOBAL_MODEL[]
 
 const GLOBAL_SCANNER = Ref(Scanner())
 
+
+"""
+    iterate_cost()
+
+Return a sequence of all the cost functions that should be optimised in order.
+"""
+function iterate_cost()
+    unique_weights = sort(unique([w for (w, _) in GLOBAL_MODEL[][2]]))
+    if iszero(length(unique_weights))
+        return [0.]
+    else
+        return [sum([f for (w, f) in GLOBAL_MODEL[][2] if w == weight]) for weight in unique_weights]
+    end
+end
+
+
 """
 Wrapper to build a sequence.
 
@@ -27,10 +43,10 @@ seq = Sequence(
 )
 ```
 
-You can also add any arbitrary constraints or objectives using the same syntax as for [`JuMP`](https://jump.dev/JuMP.jl):
-```
-@constraint global_model() duration(seq) == 30.
-```
+You can also add any arbitrary constraints or objectives using one of:
+- [`set_simple_constraints!`](@ref)
+- [`apply_simple_constraint!`](@ref)
+- [`add_cost_function!`](@ref)
 
 As soon as the code block ends the sequence is optimised (if `optimise=true`) and returned.
 
@@ -41,7 +57,7 @@ As soon as the code block ends the sequence is optimised (if `optimise=true`) an
 - `n_attempts`: How many times to restart the optimser (default: 100).
 - `kwargs...`: Other keywords are passed on as attributes to the `optimiser_constructor` (e.g., set `print_level=3` to make the Ipopt optimiser quieter).
 """
-function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Model, optimise::Bool, n_attempts::Int)
+function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Tuple, optimise::Bool, n_attempts::Int)
     prev_model = GLOBAL_MODEL[]
     GLOBAL_MODEL[] = model
     prev_scanner = GLOBAL_SCANNER[]
@@ -53,34 +69,38 @@ function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Mo
     try
         sequence = f()
         if optimise
-            if !iszero(num_variables(model))
-                min_objective = Inf
-                for attempt in 1:n_attempts
-                    if attempt != 1
-                        old_values = value.(all_variables(model))
-                        size_kick = 0.5 / attempt
-                        new_values = old_values .* (2 .* size_kick .* rand(length(old_values)) .+ 1. .- size_kick)
-                        for (var, v) in zip(all_variables(model), new_values)
-                            set_start_value(var, v)
+            jump_model = GLOBAL_MODEL[][1]
+            if !iszero(num_variables(jump_model))
+                for cost_func in iterate_cost()
+                    @objective jump_model Min cost_func
+                    min_objective = Inf
+                    for attempt in 1:n_attempts
+                        if attempt != 1
+                            old_values = value.(all_variables(jump_model))
+                            size_kick = 0.5 / attempt
+                            new_values = old_values .* (2 .* size_kick .* rand(length(old_values)) .+ 1. .- size_kick)
+                            for (var, v) in zip(all_variables(jump_model), new_values)
+                                set_start_value(var, v)
+                            end
                         end
-                    end
-                    optimize!(model)
-                    while termination_status(model) == INVALID_MODEL
-                        @variable(model)
-                        optimize!(model)
-                    end
-                    if termination_status(model) in (LOCALLY_SOLVED, OPTIMAL)
-                        if objective_value(model) < min_objective
-                            min_objective = objective_value(model)
-                        elseif isapprox(min_objective, objective_value(model), rtol=1e-6)
-                            println("Optimisation succeeded after $(attempt-1) restarts.")
-                            break
+                        optimize!(jump_model)
+                        while termination_status(jump_model) == INVALID_MODEL
+                            @variable(jump_model)
+                            optimize!(jump_model)
+                        end
+                        if termination_status(jump_model) in (LOCALLY_SOLVED, OPTIMAL)
+                            if objective_value(jump_model) < min_objective
+                                min_objective = objective_value(jump_model)
+                            elseif isapprox(min_objective, objective_value(jump_model), rtol=1e-6)
+                                break
+                            end
                         end
                     end
-                end
-                if !(termination_status(model) in (LOCALLY_SOLVED, OPTIMAL))
-                    @warn "Optimisation did not report successful convergence. Please check the output sequence."
-                    println(solution_summary(model))
+                    if !(termination_status(jump_model) in (LOCALLY_SOLVED, OPTIMAL))
+                        println(solution_summary(jump_model))
+                        error("Optimisation failed to converge.")
+                    end
+                    #@constraint jump_model cost_func == objective_value(jump_model)
                 end
             end
             return fixed(sequence)
@@ -95,9 +115,12 @@ end
 
 function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, optimiser_constructor; optimise=true, n_attempts=100, kwargs...)
     if optimise || GLOBAL_MODEL[] == IGNORE_MODEL
-        model = Model(optimizer_with_attributes(optimiser_constructor, [string(k) => v for (k, v) in kwargs]...))
+        model = (
+            Model(optimizer_with_attributes(optimiser_constructor, [string(k) => v for (k, v) in kwargs]...)),
+            Tuple{Float64, AbstractJuMPScalar}[]
+        )
     else
-        model = global_model()
+        model = GLOBAL_MODEL[]
     end
     build_sequence(f, scanner, model, optimise, n_attempts)
 end
@@ -109,20 +132,6 @@ end
 build_sequence(f::Function, optimiser_constructor; kwargs...) = build_sequence(f, Default_Scanner, optimiser_constructor; kwargs...)
 
 
-"""
-    global_model()
-
-Return the currently set JuMP model.
-
-The model can be set using [`build_sequence`](@ref)
-"""
-function global_model()
-    if GLOBAL_MODEL[] == IGNORE_MODEL
-        error("No global model has been set. Please embed any sequence creation within an `build_sequence` block.")
-    end
-    return GLOBAL_MODEL[]
-end
-
 """
     global_scanner()
 
diff --git a/src/components/instant_gradients.jl b/src/components/instant_gradients.jl
index a3ce9fc..da550bf 100644
--- a/src/components/instant_gradients.jl
+++ b/src/components/instant_gradients.jl
@@ -2,7 +2,6 @@ module InstantGradients
 import StaticArrays: SVector, SMatrix
 import JuMP: @constraint
 import ...Variables: @defvar, VariableType, variables, get_free_variable, set_simple_constraints!, make_generic, adjust_internal, adjustable, gradient_orientation, apply_simple_constraint!
-import ...BuildSequences: global_model
 import ..AbstractTypes: EventComponent, GradientWaveform
 
 """
diff --git a/src/components/pulses/composite_pulses.jl b/src/components/pulses/composite_pulses.jl
index 6a3c2a9..ad91e1b 100644
--- a/src/components/pulses/composite_pulses.jl
+++ b/src/components/pulses/composite_pulses.jl
@@ -1,7 +1,6 @@
 module CompositePulses
 import JuMP: @constraint
 import ...AbstractTypes: RFPulseComponent, split_timestep, edge_times
-import ....BuildSequences: global_model
 import ....Variables: VariableType, set_simple_constraints!, make_generic, get_free_variable, adjust_internal, variables, @defvar
 import ..GenericPulses: GenericPulse
 
@@ -49,12 +48,6 @@ function CompositePulse(; base_pulse::RFPulseComponent, nweights=nothing, weight
         scale_amplitude
     )
     return res
-    @constraint global_model() minimum(wait_times(res)) >= 0.
-    if !(res.scale_amplitude isa Number)
-        @constraint global_model() res.scale_amplitude >= 0.
-        @constraint global_model() res.scale_amplitude <= 1.
-    end
-    return res
 end
 
 Base.length(comp::CompositePulse) = length(comp.pulses)
diff --git a/src/components/pulses/constant_pulses.jl b/src/components/pulses/constant_pulses.jl
index fd591d9..4af0975 100644
--- a/src/components/pulses/constant_pulses.jl
+++ b/src/components/pulses/constant_pulses.jl
@@ -1,7 +1,6 @@
 module ConstantPulses
 import JuMP: @constraint
 import ...AbstractTypes: RFPulseComponent, split_timestep
-import ....BuildSequences: global_model
 import ....Variables: VariableType, set_simple_constraints!, make_generic, get_free_variable, adjust_internal, variables, @defvar, apply_simple_constraint!
 import ..GenericPulses: GenericPulse
 
diff --git a/src/components/pulses/instant_pulses.jl b/src/components/pulses/instant_pulses.jl
index 14b81c6..2529668 100644
--- a/src/components/pulses/instant_pulses.jl
+++ b/src/components/pulses/instant_pulses.jl
@@ -1,7 +1,6 @@
 module InstantPulses
 import JuMP: @constraint
 import ...AbstractTypes: RFPulseComponent
-import ....BuildSequences: global_model
 import ....Variables: VariableType, make_generic, get_free_variable, adjust_internal, variables, @defvar, apply_simple_constraint!
 
 """
diff --git a/src/components/pulses/sinc_pulses.jl b/src/components/pulses/sinc_pulses.jl
index 662c590..b445ef1 100644
--- a/src/components/pulses/sinc_pulses.jl
+++ b/src/components/pulses/sinc_pulses.jl
@@ -1,7 +1,6 @@
 module SincPulses
 import JuMP: @constraint
 import QuadGK: quadgk
-import ....BuildSequences: global_model
 import ....Variables: VariableType, set_simple_constraints!, make_generic, get_free_variable, adjust_internal, variables, @defvar, apply_simple_constraint!
 import ...AbstractTypes: RFPulseComponent, split_timestep
 import ..GenericPulses: GenericPulse
diff --git a/src/components/readouts/ADCs.jl b/src/components/readouts/ADCs.jl
index 7098cf4..23894cd 100644
--- a/src/components/readouts/ADCs.jl
+++ b/src/components/readouts/ADCs.jl
@@ -1,7 +1,7 @@
 module ADCs
 import JuMP: @constraint, value
 import ...AbstractTypes: ReadoutComponent
-import ....BuildSequences: global_model, fixed
+import ....BuildSequences: fixed
 import ....Variables: VariableType, apply_simple_constraint!, set_simple_constraints!, get_free_variable, make_generic, variables, @defvar
 
 
diff --git a/src/containers/alternatives.jl b/src/containers/alternatives.jl
index 731042b..8847d5f 100644
--- a/src/containers/alternatives.jl
+++ b/src/containers/alternatives.jl
@@ -1,8 +1,8 @@
 module Alternatives
 import JuMP: @constraint
 import ..Abstract: ContainerBlock
-import ...BuildSequences: global_model, fixed
-import ...Variables: @defvar, make_generic
+import ...BuildSequences: fixed
+import ...Variables: @defvar, make_generic, apply_simple_constraint!
 
 """
     AlternativeBlocks(name, blocks)
@@ -36,11 +36,7 @@ function match_blocks!(alternatives::AlternativeBlocks, func)
     options = [values(alternatives.options)...]
     baseline = func(options[1])
     for other_block in options[2:end]
-        if baseline isa AbstractVector
-            @constraint global_model() baseline == func(other_block)
-        else
-            @constraint global_model() baseline .== func(other_block)
-        end
+        apply_simple_constraint!(func(other_block), baseline)
     end
 end
 
diff --git a/src/containers/base_sequences.jl b/src/containers/base_sequences.jl
index c26e2c7..d595640 100644
--- a/src/containers/base_sequences.jl
+++ b/src/containers/base_sequences.jl
@@ -5,7 +5,7 @@ module BaseSequences
 import StaticArrays: SVector
 import JuMP: @constraint
 import ...Variables: get_free_variable, VariableType, variables, set_simple_constraints!, make_generic, get_gradient, get_pulse, get_gradient, @defvar
-import ...BuildSequences: global_model, global_scanner
+import ...BuildSequences: global_scanner
 import ...Components: EventComponent, NoGradient, edge_times
 import ...Scanners: Scanner, B0
 import ..Abstract: ContainerBlock, start_time
diff --git a/src/containers/building_blocks.jl b/src/containers/building_blocks.jl
index 90e36e7..d270a9f 100644
--- a/src/containers/building_blocks.jl
+++ b/src/containers/building_blocks.jl
@@ -6,7 +6,6 @@ import LinearAlgebra: norm
 import JuMP: @constraint
 import StaticArrays: SVector
 import ..Abstract: ContainerBlock, start_time, end_time, iter
-import ...BuildSequences: global_model
 import ...Components: BaseComponent, GradientWaveform, EventComponent, NoGradient, ChangingGradient, ConstantGradient, split_gradient, RFPulseComponent, ReadoutComponent, InstantGradient, edge_times
 import ...Variables: VariableType, make_generic, get_pulse, get_readout, scanner_constraints!, get_gradient, gradient_orientation, variables, @defvar, get_free_variable, apply_simple_constraint!
 
diff --git a/src/parts/helper_functions.jl b/src/parts/helper_functions.jl
index 9de5650..66a307a 100644
--- a/src/parts/helper_functions.jl
+++ b/src/parts/helper_functions.jl
@@ -5,7 +5,7 @@ import ..Trapezoids: Trapezoid, opposite_kspace_lines, SliceSelect
 import ..SpoiltSliceSelects: SpoiltSliceSelect
 import ..SliceSelectRephases: SliceSelectRephase
 import ..EPIReadouts: EPIReadout
-import ...BuildSequences: global_model, build_sequence, global_scanner
+import ...BuildSequences: build_sequence, global_scanner
 import ...Containers: Sequence
 import ...Components: SincPulse, ConstantPulse, InstantPulse, SingleReadout, InstantGradient
 import ...Variables: variables, apply_simple_constraint!
diff --git a/src/parts/spoilt_slice_selects.jl b/src/parts/spoilt_slice_selects.jl
index 349f58a..adef210 100644
--- a/src/parts/spoilt_slice_selects.jl
+++ b/src/parts/spoilt_slice_selects.jl
@@ -3,7 +3,7 @@ module SpoiltSliceSelects
 import LinearAlgebra: norm
 import StaticArrays: SVector
 import JuMP: @constraint, @objective, objective_function
-import ...BuildSequences: global_model, global_scanner
+import ...BuildSequences: global_scanner
 import ...Variables: VariableType, get_pulse, set_simple_constraints!, variables, @defvar, gradient_orientation
 import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent
 import ...Containers: BaseBuildingBlock
@@ -36,7 +36,6 @@ struct SpoiltSliceSelect <: BaseBuildingBlock
 end
 
 function SpoiltSliceSelect(pulse::RFPulseComponent; orientation=[0, 0, 1], group=nothing, slice_thickness=nothing, kwargs...)
-    model = global_model()
     res = nothing
     if slice_thickness isa Number && isinf(slice_thickness)
         rise_time_var = get_free_variable(nothing)
@@ -70,16 +69,16 @@ function SpoiltSliceSelect(pulse::RFPulseComponent; orientation=[0, 0, 1], group
         for time_var in (res.rise_time1, res.flat_time1, res.diff_time, res.flat_time2, res.fall_time2)
             apply_simple_constraint!(time_var, :>=, 0)
         end
-        @constraint model res.diff_time <= res.rise_time1
-        @constraint model res.diff_time <= res.fall_time2
-        @constraint model qvec(res, nothing, :pulse) == qvec(res, :pulse, nothing)
+        apply_simple_constraint!(res.diff_time, :<=, res.rise_time1)
+        apply_simple_constraint!(res.diff_time, :<=, res.rise_time2)
+        apply_simple_constraint!(qvec(res, nothing, :pulse), qvec(res, :pulse, nothing))
     end
 
     set_simple_constraints!(res, kwargs)
 
     max_time = gradient_strength(global_scanner()) / res.slew_rate
-    @constraint model rise_time(res)[1] <= max_time
-    @constraint model fall_time(res)[2] <= max_time
+    apply_simple_constraint!(rise_time(res)[1], :<=, max_time)
+    apply_simple_constraint!(fall_time(res)[2], :<=, max_time)
     return res
 end
 
diff --git a/src/parts/trapezoids.jl b/src/parts/trapezoids.jl
index c3db0ce..ddc37b9 100644
--- a/src/parts/trapezoids.jl
+++ b/src/parts/trapezoids.jl
@@ -7,7 +7,6 @@ import JuMP: @constraint
 import StaticArrays: SVector
 import LinearAlgebra: norm
 import ...Variables: variables, @defvar, scanner_constraints!, get_free_variable, set_simple_constraints!, gradient_orientation, VariableType, get_gradient, get_pulse, get_readout, adjustable, adjust_internal, apply_simple_constraint!
-import ...BuildSequences: global_model
 import ...Components: ChangingGradient, ConstantGradient, RFPulseComponent, ADC
 import ...Containers: BaseBuildingBlock
 
@@ -249,10 +248,13 @@ function LineReadout(adc::ADC; ramp_overlap=nothing, orientation=nothing, group=
     )
     set_simple_constraints!(res, vars)
     if !(res.ramp_overlap isa Number)
-        @constraint global_model() res.ramp_overlap >= 0
-        @constraint global_model() res.ramp_overlap <= 1
+        add_simple_constraint!(res.ramp_overlap, :>=, 0.)
+        add_simple_constraint!(res.ramp_overlap, :<=, 1.)
     end
-    @constraint global_model() (res.ramp_overlap * variables.rise_time(res.trapezoid) * 2 + variables.flat_time(res.trapezoid)) == variables.duration(res.adc)
+    add_simple_constraint!(
+        res.ramp_overlap * variables.rise_time(res.trapezoid) * 2 + variables.flat_time(res.trapezoid),
+        variables.duration(res.adc)
+    )
     return res
 end
 
diff --git a/src/variables.jl b/src/variables.jl
index aec9c86..76184d4 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -5,6 +5,7 @@ In addition this defines:
 - [`variables`](@ref): module containing all variables.
 - [`VariableType`](@ref): parent type for any variables (whether number or JuMP variable).
 - [`get_free_variable`](@ref): helper function to create new JuMP variables.
+- [`add_cost_function!`](@ref): add a specific term to the model cost functions.
 - [`set_simple_constraints!`](@ref): call [`apply_simple_constraint!`](@ref) for each keyword argument.
 - [`apply_simple_constraint!`](@ref): set a simple equality constraint.
 - [`get_pulse`](@ref)/[`get_gradient`](@ref)/[`get_readout`](@ref): Used to get the pulse/gradient/readout part of a building block
@@ -15,7 +16,7 @@ import JuMP: @constraint, @variable, Model, @objective, objective_function, Abst
 import StaticArrays: SVector
 import MacroTools
 import ..Scanners: gradient_strength, slew_rate, Scanner
-import ..BuildSequences: global_model, global_scanner, fixed
+import ..BuildSequences: global_scanner, fixed, GLOBAL_MODEL
 
 """
 Parent type of all components, building block, and sequences that form an MRI sequence.
@@ -343,18 +344,16 @@ The result is guaranteed to be a [`VariableType`](@ref).
 """
 get_free_variable(value::Number; integer=false, kwargs...) = integer ? Int(value) : Float64(value)
 get_free_variable(value::VariableType; kwargs...) = value
-get_free_variable(::Nothing; integer=false, start=0.01) = @variable(global_model(), start=start, integer=integer)
+get_free_variable(::Nothing; integer=false, start=0.01) = @variable(GLOBAL_MODEL[][1], start=start, integer=integer)
 get_free_variable(value::Symbol; integer=false, kwargs...) = integer ? error("Cannot maximise or minimise an integer variable") : get_free_variable(Val(value); kwargs...)
 function get_free_variable(::Val{:min}; kwargs...)
     var = get_free_variable(nothing; kwargs...)
-    model = global_model()
-    @objective model Min objective_function(model) + var
+    add_cost_function!(var, 1)
     return var
 end
 function get_free_variable(::Val{:max}; kwargs...)
     var = get_free_variable(nothing; kwargs...)
-    model = global_model()
-    @objective model Min objective_function(model) - var
+    add_cost_function!(-var, 1)
     return var
 end
 
@@ -453,6 +452,21 @@ function (var::AlternateVariable)(args...; kwargs...)
 end
 
 
+"""
+    add_cost_function!(function, level=2)
+
+Adds an additional term to the cost function.
+
+This term will be minimised together with any other terms in the cost function.
+Terms added at a lower level will be optimised before any terms with a higher level.
+
+By default, the term is added to the `level=2`, which is appropriate for any cost functions added by the developer,
+which will generally only be optimised after any user-defined cost functions (which are added at `level=1` by [`add_simple_constraint!`](@ref) or [`set_simple_constraints!`](@ref).
+"""
+function add_cost_function!(func, level=2)
+    push!(GLOBAL_MODEL[][2], (Float64(level), func))
+end
+
 """
     set_simple_constraints!(block, kwargs)
 
@@ -507,9 +521,9 @@ apply_simple_constraint!(variable::AbstractVector, value::Symbol) = apply_simple
 apply_simple_constraint!(variable, value::Nothing) = nothing
 apply_simple_constraint!(variable::NamedTuple, value::Nothing) = nothing
 apply_simple_constraint!(variable::VariableType, value::Symbol) = apply_simple_constraint!(variable, Val(value))
-apply_simple_constraint!(variable::VariableType, ::Val{:min}) = @objective global_model() Min objective_function(global_model()) + variable
-apply_simple_constraint!(variable::VariableType, ::Val{:max}) = @objective global_model() Min objective_function(global_model()) - variable
-apply_simple_constraint!(variable::VariableType, value::VariableType) = @constraint global_model() variable == value
+apply_simple_constraint!(variable::VariableType, ::Val{:min}) = add_cost_function!(variable, 1)
+apply_simple_constraint!(variable::VariableType, ::Val{:max}) = add_cost_function!(-variable, 1)
+apply_simple_constraint!(variable::VariableType, value::VariableType) = @constraint GLOBAL_MODEL[][1] variable == value
 apply_simple_constraint!(variable::AbstractVector, value::AbstractVector) = [apply_simple_constraint!(v1, v2) for (v1, v2) in zip(variable, value)]
 apply_simple_constraint!(variable::AbstractVector, value::VariableType) = [apply_simple_constraint!(v1, value) for v1 in variable]
 apply_simple_constraint!(variable::Number, value::Number) = @assert variable ≈ value "Variable set to multiple incompatible values."
@@ -526,12 +540,12 @@ end
 
 apply_simple_constraint!(variable::VariableType, sign::Symbol, value) = apply_simple_constraint!(variable, Val(sign), value)
 apply_simple_constraint!(variable::VariableType, ::Val{:(==)}, value) = apply_simple_constraint!(variable, value)
-apply_simple_constraint!(variable::VariableType, ::Val{:(<=)}, value::VariableType) = apply_simple_constraint!(variable, Val(:>=), -value)
+apply_simple_constraint!(variable::VariableType, ::Val{:(<=)}, value::VariableType) = apply_simple_constraint!(value, Val(:>=), variable)
 function apply_simple_constraint!(variable::VariableType, ::Val{:(>=)}, value::VariableType)
-    @constraint global_model() variable >= value
+    @constraint GLOBAL_MODEL[][1] variable >= value
     if value isa Number && iszero(value)
-        @constraint global_model() variable * 1e6 >= value
-        @constraint global_model() variable * 1e12 >= value
+        @constraint GLOBAL_MODEL[][1] variable * 1e6 >= value
+        @constraint GLOBAL_MODEL[][1] variable * 1e12 >= value
     end
 end
 
@@ -569,8 +583,8 @@ function scanner_constraints!(bb::AbstractBlock)
             if v isa Number || ((v isa Union{QuadExpr, AffExpr}) && length(v.terms) == 0)
                 continue
             end
-            @constraint global_model() v <= max_value
-            @constraint global_model() v >= -max_value
+            apply_simple_constraint!(v, :<=, max_value)
+            apply_simple_constraint!(v, :>=, -max_value)
         end
     end
 end
-- 
GitLab