From 6d2b35614ead8efda31854ae61fcc796282c6d58 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Fri, 23 Feb 2024 10:58:28 +0000
Subject: [PATCH] Adjust values before rather than after optimisation

Also, add nattempts as a keyword argument
---
 src/build_sequences.jl | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/build_sequences.jl b/src/build_sequences.jl
index 3b7cf75..2252aaf 100644
--- a/src/build_sequences.jl
+++ b/src/build_sequences.jl
@@ -37,12 +37,11 @@ As soon as the code block ends the sequence is optimised (if `optimise=true`) an
 ## Parameters
 - `scanner`: Set to a [`Scanner`](@ref) to limit the gradient strength and slew rate. When this call to `build_sequence` is embedded in another, this parameter can be set to `nothing` to indicate that the same scanner should be used. 
 - `optimiser_constructor`: A `JuMP` solver optimiser as described in the [JuMP documentation](https://jump.dev/JuMP.jl/stable/tutorials/getting_started/getting_started_with_JuMP/#What-is-a-solver?). Defaults to using [Ipopt](https://github.com/jump-dev/Ipopt.jl).
-
-## Variables
 - `optimise`: Whether to optimise and fix the sequence as soon as it is returned. This defaults to `true` if a scanner is provided and `false` if no scanner is provided.
+- `n_attempts`: How many times to restart the optimser (default: 100).
 - `kwargs...`: Other keywords are passed on as attributes to the `optimiser_constructor` (e.g., set `print_level=3` to make the Ipopt optimiser quieter).
 """
-function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Model, optimise::Bool)
+function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Model, optimise::Bool, n_attempts::Int)
     prev_model = GLOBAL_MODEL[]
     GLOBAL_MODEL[] = model
     prev_scanner = GLOBAL_SCANNER[]
@@ -55,12 +54,8 @@ function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Mo
         sequence = f()
         if optimise
             if !iszero(num_variables(model))
-                for attempt in 1:100
-                    optimize!(model)
-                    if termination_status(model) in (LOCALLY_SOLVED, OPTIMAL)
-                        println("Optimisation succeeded after $(attempt-1) restarts.")
-                        break
-                    else
+                for attempt in 1:n_attempts
+                    if attempt != 1
                         old_values = value.(all_variables(model))
                         size_kick = 0.5 / attempt
                         new_values = old_values .* (2 .* size_kick .* rand(length(old_values)) .+ 1. .- size_kick)
@@ -68,6 +63,11 @@ function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Mo
                             set_start_value(var, v)
                         end
                     end
+                    optimize!(model)
+                    if termination_status(model) in (LOCALLY_SOLVED, OPTIMAL)
+                        println("Optimisation succeeded after $(attempt-1) restarts.")
+                        break
+                    end
                 end
                 if !(termination_status(model) in (LOCALLY_SOLVED, OPTIMAL))
                     @warn "Optimisation did not report successful convergence. Please check the output sequence."
@@ -84,13 +84,13 @@ function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, model::Mo
     end
 end
 
-function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, optimiser_constructor; optimise=true, kwargs...)
+function build_sequence(f::Function, scanner::Union{Nothing, Scanner}, optimiser_constructor; optimise=true, n_attempts=100, kwargs...)
     if optimise || GLOBAL_MODEL[] == IGNORE_MODEL
         model = Model(optimizer_with_attributes(optimiser_constructor, [string(k) => v for (k, v) in kwargs]...))
     else
         model = global_model()
     end
-    build_sequence(f, scanner, model, optimise)
+    build_sequence(f, scanner, model, optimise, n_attempts)
 end
 
 function build_sequence(f::Function, scanner::Union{Nothing, Scanner}; print_level=2, kwargs...)
-- 
GitLab