From e60211c3ccd065dbe45cbd16fbc4c140702af80f Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@protonmail.com>
Date: Wed, 18 Sep 2024 13:10:39 +0100
Subject: [PATCH] ENH: allow for integer variables

---
 src/build_sequences.jl | 50 +++++++++++++++++++++++++++++++++---------
 1 file changed, 40 insertions(+), 10 deletions(-)

diff --git a/src/build_sequences.jl b/src/build_sequences.jl
index 563afaf..9d5ba49 100644
--- a/src/build_sequences.jl
+++ b/src/build_sequences.jl
@@ -73,6 +73,7 @@ function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Tu
         sequence = f()
         if optimise
             jump_model = GLOBAL_MODEL[][1]
+            set_optimizer(jump_model, get_optimiser(jump_model))
             if !iszero(num_variables(jump_model))
                 optimise_with_cost_func!(jump_model, total_cost_func(), n_attempts)
                 prev_cost_func = nothing
@@ -98,6 +99,37 @@ function number_equality_constraints(model::Model)
     sum([num_constraints(model, expr, comp) for (expr, comp) in JuMP.list_of_constraint_types(model) if comp <: MOI.EqualTo])
 end
 
+"""
+    model_has_integers(model)
+
+Returns true if the model contains integer variables.
+"""
+model_has_integers(model::Model) = (VariableRef, MOI.Integer) in JuMP.list_of_constraint_types(model)
+
+"""
+    get_optimiser(model)
+
+Returns a JuMP solver (https://jump.dev/JuMP.jl/stable/installation/#Supported-solvers) appropriate for this model.
+"""
+function get_optimiser(model::Model)
+    base_ipopt = optimizer_with_attributes(
+        Ipopt.Optimizer,
+        "print_level" => 0,
+        "mu_strategy" => "adaptive",
+        "max_iter" => 1000,
+    )
+    if model_has_integers(model)
+        return optimizer_with_attributes(
+            Juniper.Optimizer,
+            "nl_solver" => base_ipopt,
+            "log_levels" => [],
+        )
+    else
+        return base_ipopt
+    end
+end
+
+
 function optimise_with_cost_func!(jump_model::Model, cost_func, n_attempts)
     @objective jump_model Min cost_func
     min_objective = Inf
@@ -130,7 +162,7 @@ function optimise_with_cost_func!(jump_model::Model, cost_func, n_attempts)
                 nsuccess += 1
                 if objective_value(jump_model) < min_objective
                     min_objective = objective_value(jump_model)
-                    min_values = copy(backend(jump_model).optimizer.model.inner.x)
+                    min_values = copy(get_inner_state(jump_model))
                 end
                 break
             elseif sub_attempt == 3
@@ -150,20 +182,18 @@ function optimise_with_cost_func!(jump_model::Model, cost_func, n_attempts)
         end
         error("Optimisation failed to converge. The following errors were raised: $err_string. Example errors for each type are printed above.")
     end
-    backend(jump_model).optimizer.model.inner.x .= min_values
+    get_inner_state(jump_model) .= min_values
     @assert value(cost_func) ≈ min_objective eps(min_objective)
 end
 
-function build_sequence(f::Function, scanner::Union{Nothing, Scanner}=Default_Scanner; optimise=true, n_attempts=20, print_level=0, mu_strategy="adaptive", max_iter=1000, kwargs...)
+get_inner_state(backend::Juniper.JuniperProblem) = backend.solution
+get_inner_state(backend::Ipopt.IpoptProblem) = backend.x
+get_inner_state(model::Model) = get_inner_state(backend(model).optimizer.model.inner)
+
+function build_sequence(f::Function, scanner::Union{Nothing, Scanner}=Default_Scanner; optimise=true, n_attempts=20)
     if optimise || GLOBAL_MODEL[] == IGNORE_MODEL
-        full_kwargs = Dict(
-            :print_level => print_level,
-            :mu_strategy => mu_strategy,
-            :max_iter => max_iter,
-            kwargs...
-        )
         model = (
-            Model(optimizer_with_attributes(Ipopt.Optimizer, [string(k) => v for (k, v) in full_kwargs]...)),
+            Model(),
             Tuple{Float64, AbstractJuMPScalar}[]
         )
     else
-- 
GitLab