From 76eb80627d71172dcc0ca9001eec69b3a5300c7d Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Fri, 2 Feb 2024 10:57:00 +0000
Subject: [PATCH] Replace variables with numbers after optimisation

---
 src/build_sequences.jl                    | 17 +++++++++++++--
 src/building_blocks.jl                    | 25 ++++++++++-------------
 src/gradients/changing_gradient_blocks.jl |  2 +-
 src/gradients/constant_gradient_blocks.jl |  2 +-
 src/overlapping/spoilt_slice_selects.jl   |  2 +-
 src/overlapping/trapezoid_gradients.jl    |  2 +-
 src/pulses/constant_pulses.jl             |  3 +--
 src/pulses/instant_pulses.jl              |  2 +-
 src/pulses/sinc_pulses.jl                 |  4 ++--
 src/readouts/instant_readouts.jl          |  3 +--
 src/sequences.jl                          |  3 ++-
 11 files changed, 37 insertions(+), 28 deletions(-)

diff --git a/src/build_sequences.jl b/src/build_sequences.jl
index 5600724..461e351 100644
--- a/src/build_sequences.jl
+++ b/src/build_sequences.jl
@@ -1,5 +1,5 @@
 module BuildSequences
-import JuMP: Model, optimizer_with_attributes, optimize!
+import JuMP: Model, optimizer_with_attributes, optimize!, AbstractJuMPScalar, value
 import Ipopt
 import Juniper
 import ..Scanners: Scanner
@@ -44,7 +44,7 @@ function build_sequence(f::Function, scanner::Scanner, model::Model)
     try
         sequence = f(model)
         optimize!(model)
-        return sequence
+        return fixed(sequence)
     finally
         GLOBAL_MODEL[] = prev_model
         GLOBAL_SCANNER[] = prev_scanner
@@ -74,4 +74,17 @@ function global_scanner()
 end
 
 
+"""
+    fixed(building_block)
+
+Returns an equiavalent [`BuildingBlock`](@ref) with all free variables replaced by numbers.
+This will only work after calling [`optimize!`](@ref)([`global_model`](@ref)()).
+It is used internally by [`build_sequence`](@ref).
+"""
+fixed(some_value) = some_value
+fixed(jump_variable::AbstractJuMPScalar) = value(jump_variable)
+fixed(jump_variable::AbstractArray) = fixed.(jump_variable)
+
+
+
 end
\ No newline at end of file
diff --git a/src/building_blocks.jl b/src/building_blocks.jl
index a2cab1d..4d9d459 100644
--- a/src/building_blocks.jl
+++ b/src/building_blocks.jl
@@ -2,7 +2,7 @@ module BuildingBlocks
 import JuMP: value, Model, @constraint, @objective, objective_function, AbstractJuMPScalar
 import Printf: @sprintf
 import ..Variables: variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, qval_square
-import ..BuildSequences: global_model, global_scanner
+import ..BuildSequences: global_model, global_scanner, fixed
 import ..Scanners: Scanner
 
 """
@@ -10,7 +10,6 @@ Parent type for all individual components out of which a sequence can be built.
 
 Required methods:
 - [`duration`](@ref)(block, parameters): Return block duration in ms.
-- [`fixed`](block): Return the equivalent fixed BuildingBlock (i.e., `FixedBlock`, `FixedPulse`, `FixedGradient`, `FixedInstantPulse`, `FixedInstantGradient`, or `InstantReadout`). These all have in common that they have no free variables and explicitly set any gradient and RF pulse profiles.
 - [`variables`](@ref): A list of all functions that are used to compute variables of the building block. Any of these can be used in constraints or objective functions.
 """
 abstract type BuildingBlock end
@@ -101,17 +100,6 @@ Function used internally to convert a wide variety of objects into [`BuildingBlo
 to_block(bb::BuildingBlock) = bb
 
 
-"""
-    fixed(block::BuildingBlock)
-
-Return the fixed equivalent of the `BuildingBlock`
-
-Possible return types are `FixedSequence`, `FixedBlock`, `FixedPulse`, `FixedGradient`, `FixedInstantPulse`, `FixedInstantGradient`, or `InstantReadout`. 
-These all have in common that they have no free variables and explicitly set any gradient and RF pulse profiles.
-"""
-function fixed end
-
-
 
 """
     variables(building_block)
@@ -251,7 +239,7 @@ apply_simple_constraint!(variable, ::Nothing) = nothing
 apply_simple_constraint!(variable, value::Symbol) = apply_simple_constraint!(variable, Val(value))
 apply_simple_constraint!(variable, ::Val{:min}) = @objective global_model() Min objective_function(global_model()) + variable
 apply_simple_constraint!(variable, ::Val{:max}) = @objective global_model() Min objective_function(global_model()) - variable
-apply_simple_constraint!(variable, value::VariableType) = @constraint model variable == value
+apply_simple_constraint!(variable, value::VariableType) = @constraint global_model() variable == value
 apply_simple_constraint!(variable::AbstractVector, value::AbstractVector) = [apply_simple_constraint!(v1, v2) for (v1, v2) in zip(variable, value)]
 
 
@@ -318,4 +306,13 @@ function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner, f
     end
 end
 
+
+function fixed(bb::BuildingBlock)
+    arguments = []
+    for name in propertynames(bb)
+        push!(arguments, fixed(getproperty(bb, name)))
+    end
+    return typeof(bb)(arguments...)
+end
+
 end
\ No newline at end of file
diff --git a/src/gradients/changing_gradient_blocks.jl b/src/gradients/changing_gradient_blocks.jl
index cf2fd57..5ecba8a 100644
--- a/src/gradients/changing_gradient_blocks.jl
+++ b/src/gradients/changing_gradient_blocks.jl
@@ -3,7 +3,7 @@ import StaticArrays: SVector
 import ...Variables: VariableType, variables, get_free_variable
 import ...BuildingBlocks: GradientBlock
 import ...Variables: qvec, bmat_gradient, gradient_strength, slew_rate, duration, variables, VariableType
-import ...BuildingBlocks: GradientBlock, fixed, RFPulseBlock
+import ...BuildingBlocks: GradientBlock, RFPulseBlock
 
 """
     ChangingGradientBlock(grad1, slew_rate, duration, rotate, scale)
diff --git a/src/gradients/constant_gradient_blocks.jl b/src/gradients/constant_gradient_blocks.jl
index 61e32f0..858081c 100644
--- a/src/gradients/constant_gradient_blocks.jl
+++ b/src/gradients/constant_gradient_blocks.jl
@@ -3,7 +3,7 @@ import StaticArrays: SVector
 import ...Variables: VariableType, variables
 import ...BuildingBlocks: GradientBlock
 import ...Variables: qvec, bmat_gradient, gradient_strength, slew_rate, duration, variables, VariableType
-import ...BuildingBlocks: GradientBlock, fixed, RFPulseBlock
+import ...BuildingBlocks: GradientBlock, RFPulseBlock
 import ..ChangingGradientBlocks: split_gradient
 
 """
diff --git a/src/overlapping/spoilt_slice_selects.jl b/src/overlapping/spoilt_slice_selects.jl
index e218625..2b0a684 100644
--- a/src/overlapping/spoilt_slice_selects.jl
+++ b/src/overlapping/spoilt_slice_selects.jl
@@ -66,7 +66,7 @@ function SpoiltSliceSelect(pulse; orientation=[0, 0, 1], rotate=nothing, spoiler
             @constraint model (qvec(res, nothing, 1)[dim] ./ res.orientation[dim]) >= 2Ï€ * 1e-3 / spoiler_scale
         end
     end
-    set_simple_constraints!(model, res, kwargs)
+    set_simple_constraints!(res, kwargs)
 
     max_time = gradient_strength(global_scanner()) / res.slew_rate
     @constraint model rise_time(res)[1] <= max_time
diff --git a/src/overlapping/trapezoid_gradients.jl b/src/overlapping/trapezoid_gradients.jl
index 11b5d86..0665ca6 100644
--- a/src/overlapping/trapezoid_gradients.jl
+++ b/src/overlapping/trapezoid_gradients.jl
@@ -7,7 +7,7 @@ import JuMP: @constraint, @variable, VariableRef, value
 import StaticArrays: SVector
 import LinearAlgebra: norm
 import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_slice_thickness, inverse_bandwidth, effective_time
-import ...BuildingBlocks: duration, set_simple_constraints!, fixed, RFPulseBlock, scanner_constraints!
+import ...BuildingBlocks: duration, set_simple_constraints!, RFPulseBlock, scanner_constraints!
 import ...BuildSequences: global_model
 import ...Gradients: ChangingGradientBlock, ConstantGradientBlock
 import ..Abstract: interruptions, waveform, AbstractOverlapping
diff --git a/src/pulses/constant_pulses.jl b/src/pulses/constant_pulses.jl
index cf069cb..07ac767 100644
--- a/src/pulses/constant_pulses.jl
+++ b/src/pulses/constant_pulses.jl
@@ -1,9 +1,8 @@
 module ConstantPulses
 import JuMP: VariableRef, @constraint, @variable, value
-import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed
+import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!
 import ...Variables: variables, get_free_variable, flip_angle, phase, amplitude, frequency, start_time, end_time, VariableType, duration, effective_time, inverse_bandwidth
 import ...BuildSequences: global_model
-import ..FixedPulses: FixedPulse
 
 """
     ConstantPulse(; variables...)
diff --git a/src/pulses/instant_pulses.jl b/src/pulses/instant_pulses.jl
index 4dafb7d..bb8e195 100644
--- a/src/pulses/instant_pulses.jl
+++ b/src/pulses/instant_pulses.jl
@@ -1,6 +1,6 @@
 module InstantPulses
 import JuMP: @constraint, @variable, VariableRef, value
-import ...BuildingBlocks: RFPulseBlock, fixed
+import ...BuildingBlocks: RFPulseBlock
 import ...Variables: flip_angle, phase, start_time, variables, duration, get_free_variable, VariableType, effective_time, inverse_bandwidth
 import ...BuildSequences: global_model
 import ..FixedPulses: FixedInstantPulse
diff --git a/src/pulses/sinc_pulses.jl b/src/pulses/sinc_pulses.jl
index 35db4b2..257e4ef 100644
--- a/src/pulses/sinc_pulses.jl
+++ b/src/pulses/sinc_pulses.jl
@@ -3,7 +3,7 @@ module SincPulses
 import JuMP: VariableRef, @constraint, @variable, value
 import QuadGK: quadgk
 import Polynomials: fit, Polynomial
-import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed
+import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!
 import ...Variables: flip_angle, phase, amplitude, frequency, VariableType, variables, get_free_variable, duration, effective_time, inverse_bandwidth
 import ...BuildSequences: global_model
 import ..FixedPulses: FixedPulse
@@ -69,7 +69,7 @@ function SincPulse(;
     if !symmetric
         @constraint model res.N_right >= 1
     end
-    set_simple_constraints!(model, res, kwargs)
+    set_simple_constraints!(res, kwargs)
     return res
 end
 
diff --git a/src/readouts/instant_readouts.jl b/src/readouts/instant_readouts.jl
index 21e74a6..9bf9a29 100644
--- a/src/readouts/instant_readouts.jl
+++ b/src/readouts/instant_readouts.jl
@@ -1,5 +1,5 @@
 module InstantReadouts
-import ...BuildingBlocks: BuildingBlock, to_block, fixed, effective_time
+import ...BuildingBlocks: BuildingBlock, to_block, effective_time
 import ...Variables: variables, duration
 
 """
@@ -15,6 +15,5 @@ end
 variables(::Type{<:InstantReadout}) = []
 to_block(::Type{<:InstantReadout}) = InstantReadout()
 duration(::InstantReadout) = 0.
-fixed(i::InstantReadout) = i
 effective_time(::InstantReadout) = 0.
 end
\ No newline at end of file
diff --git a/src/sequences.jl b/src/sequences.jl
index bd5910c..b66225d 100644
--- a/src/sequences.jl
+++ b/src/sequences.jl
@@ -6,7 +6,7 @@ import Printf: @sprintf
 import JuMP: @constraint
 import ...BuildSequences: global_model
 import ...Variables: variables, start_time, duration, VariableType, get_free_variable, TR, end_time
-import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks
+import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks, fixed
 
 """
     Sequence(building_blocks...; TR=nothing)
@@ -36,6 +36,7 @@ struct Sequence <: ContainerBlock
 end
 
 Sequence(blocks...; TR=nothing) = Sequence([blocks...]; TR=TR)
+fixed(seq::Sequence) = fixed.(seq._blocks, TR=fixed(seq.TR))
 
 Base.length(seq::Sequence) = length(seq._blocks)
 Base.getindex(seq::Sequence, index) = seq._blocks[index]
-- 
GitLab