From 76eb80627d71172dcc0ca9001eec69b3a5300c7d Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk> Date: Fri, 2 Feb 2024 10:57:00 +0000 Subject: [PATCH] Replace variables with numbers after optimisation --- src/build_sequences.jl | 17 +++++++++++++-- src/building_blocks.jl | 25 ++++++++++------------- src/gradients/changing_gradient_blocks.jl | 2 +- src/gradients/constant_gradient_blocks.jl | 2 +- src/overlapping/spoilt_slice_selects.jl | 2 +- src/overlapping/trapezoid_gradients.jl | 2 +- src/pulses/constant_pulses.jl | 3 +-- src/pulses/instant_pulses.jl | 2 +- src/pulses/sinc_pulses.jl | 4 ++-- src/readouts/instant_readouts.jl | 3 +-- src/sequences.jl | 3 ++- 11 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/build_sequences.jl b/src/build_sequences.jl index 5600724..461e351 100644 --- a/src/build_sequences.jl +++ b/src/build_sequences.jl @@ -1,5 +1,5 @@ module BuildSequences -import JuMP: Model, optimizer_with_attributes, optimize! +import JuMP: Model, optimizer_with_attributes, optimize!, AbstractJuMPScalar, value import Ipopt import Juniper import ..Scanners: Scanner @@ -44,7 +44,7 @@ function build_sequence(f::Function, scanner::Scanner, model::Model) try sequence = f(model) optimize!(model) - return sequence + return fixed(sequence) finally GLOBAL_MODEL[] = prev_model GLOBAL_SCANNER[] = prev_scanner @@ -74,4 +74,17 @@ function global_scanner() end +""" + fixed(building_block) + +Returns an equiavalent [`BuildingBlock`](@ref) with all free variables replaced by numbers. +This will only work after calling [`optimize!`](@ref)([`global_model`](@ref)()). +It is used internally by [`build_sequence`](@ref). +""" +fixed(some_value) = some_value +fixed(jump_variable::AbstractJuMPScalar) = value(jump_variable) +fixed(jump_variable::AbstractArray) = fixed.(jump_variable) + + + end \ No newline at end of file diff --git a/src/building_blocks.jl b/src/building_blocks.jl index a2cab1d..4d9d459 100644 --- a/src/building_blocks.jl +++ b/src/building_blocks.jl @@ -2,7 +2,7 @@ module BuildingBlocks import JuMP: value, Model, @constraint, @objective, objective_function, AbstractJuMPScalar import Printf: @sprintf import ..Variables: variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, qval_square -import ..BuildSequences: global_model, global_scanner +import ..BuildSequences: global_model, global_scanner, fixed import ..Scanners: Scanner """ @@ -10,7 +10,6 @@ Parent type for all individual components out of which a sequence can be built. Required methods: - [`duration`](@ref)(block, parameters): Return block duration in ms. -- [`fixed`](block): Return the equivalent fixed BuildingBlock (i.e., `FixedBlock`, `FixedPulse`, `FixedGradient`, `FixedInstantPulse`, `FixedInstantGradient`, or `InstantReadout`). These all have in common that they have no free variables and explicitly set any gradient and RF pulse profiles. - [`variables`](@ref): A list of all functions that are used to compute variables of the building block. Any of these can be used in constraints or objective functions. """ abstract type BuildingBlock end @@ -101,17 +100,6 @@ Function used internally to convert a wide variety of objects into [`BuildingBlo to_block(bb::BuildingBlock) = bb -""" - fixed(block::BuildingBlock) - -Return the fixed equivalent of the `BuildingBlock` - -Possible return types are `FixedSequence`, `FixedBlock`, `FixedPulse`, `FixedGradient`, `FixedInstantPulse`, `FixedInstantGradient`, or `InstantReadout`. -These all have in common that they have no free variables and explicitly set any gradient and RF pulse profiles. -""" -function fixed end - - """ variables(building_block) @@ -251,7 +239,7 @@ apply_simple_constraint!(variable, ::Nothing) = nothing apply_simple_constraint!(variable, value::Symbol) = apply_simple_constraint!(variable, Val(value)) apply_simple_constraint!(variable, ::Val{:min}) = @objective global_model() Min objective_function(global_model()) + variable apply_simple_constraint!(variable, ::Val{:max}) = @objective global_model() Min objective_function(global_model()) - variable -apply_simple_constraint!(variable, value::VariableType) = @constraint model variable == value +apply_simple_constraint!(variable, value::VariableType) = @constraint global_model() variable == value apply_simple_constraint!(variable::AbstractVector, value::AbstractVector) = [apply_simple_constraint!(v1, v2) for (v1, v2) in zip(variable, value)] @@ -318,4 +306,13 @@ function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner, f end end + +function fixed(bb::BuildingBlock) + arguments = [] + for name in propertynames(bb) + push!(arguments, fixed(getproperty(bb, name))) + end + return typeof(bb)(arguments...) +end + end \ No newline at end of file diff --git a/src/gradients/changing_gradient_blocks.jl b/src/gradients/changing_gradient_blocks.jl index cf2fd57..5ecba8a 100644 --- a/src/gradients/changing_gradient_blocks.jl +++ b/src/gradients/changing_gradient_blocks.jl @@ -3,7 +3,7 @@ import StaticArrays: SVector import ...Variables: VariableType, variables, get_free_variable import ...BuildingBlocks: GradientBlock import ...Variables: qvec, bmat_gradient, gradient_strength, slew_rate, duration, variables, VariableType -import ...BuildingBlocks: GradientBlock, fixed, RFPulseBlock +import ...BuildingBlocks: GradientBlock, RFPulseBlock """ ChangingGradientBlock(grad1, slew_rate, duration, rotate, scale) diff --git a/src/gradients/constant_gradient_blocks.jl b/src/gradients/constant_gradient_blocks.jl index 61e32f0..858081c 100644 --- a/src/gradients/constant_gradient_blocks.jl +++ b/src/gradients/constant_gradient_blocks.jl @@ -3,7 +3,7 @@ import StaticArrays: SVector import ...Variables: VariableType, variables import ...BuildingBlocks: GradientBlock import ...Variables: qvec, bmat_gradient, gradient_strength, slew_rate, duration, variables, VariableType -import ...BuildingBlocks: GradientBlock, fixed, RFPulseBlock +import ...BuildingBlocks: GradientBlock, RFPulseBlock import ..ChangingGradientBlocks: split_gradient """ diff --git a/src/overlapping/spoilt_slice_selects.jl b/src/overlapping/spoilt_slice_selects.jl index e218625..2b0a684 100644 --- a/src/overlapping/spoilt_slice_selects.jl +++ b/src/overlapping/spoilt_slice_selects.jl @@ -66,7 +66,7 @@ function SpoiltSliceSelect(pulse; orientation=[0, 0, 1], rotate=nothing, spoiler @constraint model (qvec(res, nothing, 1)[dim] ./ res.orientation[dim]) >= 2π * 1e-3 / spoiler_scale end end - set_simple_constraints!(model, res, kwargs) + set_simple_constraints!(res, kwargs) max_time = gradient_strength(global_scanner()) / res.slew_rate @constraint model rise_time(res)[1] <= max_time diff --git a/src/overlapping/trapezoid_gradients.jl b/src/overlapping/trapezoid_gradients.jl index 11b5d86..0665ca6 100644 --- a/src/overlapping/trapezoid_gradients.jl +++ b/src/overlapping/trapezoid_gradients.jl @@ -7,7 +7,7 @@ import JuMP: @constraint, @variable, VariableRef, value import StaticArrays: SVector import LinearAlgebra: norm import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_slice_thickness, inverse_bandwidth, effective_time -import ...BuildingBlocks: duration, set_simple_constraints!, fixed, RFPulseBlock, scanner_constraints! +import ...BuildingBlocks: duration, set_simple_constraints!, RFPulseBlock, scanner_constraints! import ...BuildSequences: global_model import ...Gradients: ChangingGradientBlock, ConstantGradientBlock import ..Abstract: interruptions, waveform, AbstractOverlapping diff --git a/src/pulses/constant_pulses.jl b/src/pulses/constant_pulses.jl index cf069cb..07ac767 100644 --- a/src/pulses/constant_pulses.jl +++ b/src/pulses/constant_pulses.jl @@ -1,9 +1,8 @@ module ConstantPulses import JuMP: VariableRef, @constraint, @variable, value -import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed +import ...BuildingBlocks: RFPulseBlock, set_simple_constraints! import ...Variables: variables, get_free_variable, flip_angle, phase, amplitude, frequency, start_time, end_time, VariableType, duration, effective_time, inverse_bandwidth import ...BuildSequences: global_model -import ..FixedPulses: FixedPulse """ ConstantPulse(; variables...) diff --git a/src/pulses/instant_pulses.jl b/src/pulses/instant_pulses.jl index 4dafb7d..bb8e195 100644 --- a/src/pulses/instant_pulses.jl +++ b/src/pulses/instant_pulses.jl @@ -1,6 +1,6 @@ module InstantPulses import JuMP: @constraint, @variable, VariableRef, value -import ...BuildingBlocks: RFPulseBlock, fixed +import ...BuildingBlocks: RFPulseBlock import ...Variables: flip_angle, phase, start_time, variables, duration, get_free_variable, VariableType, effective_time, inverse_bandwidth import ...BuildSequences: global_model import ..FixedPulses: FixedInstantPulse diff --git a/src/pulses/sinc_pulses.jl b/src/pulses/sinc_pulses.jl index 35db4b2..257e4ef 100644 --- a/src/pulses/sinc_pulses.jl +++ b/src/pulses/sinc_pulses.jl @@ -3,7 +3,7 @@ module SincPulses import JuMP: VariableRef, @constraint, @variable, value import QuadGK: quadgk import Polynomials: fit, Polynomial -import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed +import ...BuildingBlocks: RFPulseBlock, set_simple_constraints! import ...Variables: flip_angle, phase, amplitude, frequency, VariableType, variables, get_free_variable, duration, effective_time, inverse_bandwidth import ...BuildSequences: global_model import ..FixedPulses: FixedPulse @@ -69,7 +69,7 @@ function SincPulse(; if !symmetric @constraint model res.N_right >= 1 end - set_simple_constraints!(model, res, kwargs) + set_simple_constraints!(res, kwargs) return res end diff --git a/src/readouts/instant_readouts.jl b/src/readouts/instant_readouts.jl index 21e74a6..9bf9a29 100644 --- a/src/readouts/instant_readouts.jl +++ b/src/readouts/instant_readouts.jl @@ -1,5 +1,5 @@ module InstantReadouts -import ...BuildingBlocks: BuildingBlock, to_block, fixed, effective_time +import ...BuildingBlocks: BuildingBlock, to_block, effective_time import ...Variables: variables, duration """ @@ -15,6 +15,5 @@ end variables(::Type{<:InstantReadout}) = [] to_block(::Type{<:InstantReadout}) = InstantReadout() duration(::InstantReadout) = 0. -fixed(i::InstantReadout) = i effective_time(::InstantReadout) = 0. end \ No newline at end of file diff --git a/src/sequences.jl b/src/sequences.jl index bd5910c..b66225d 100644 --- a/src/sequences.jl +++ b/src/sequences.jl @@ -6,7 +6,7 @@ import Printf: @sprintf import JuMP: @constraint import ...BuildSequences: global_model import ...Variables: variables, start_time, duration, VariableType, get_free_variable, TR, end_time -import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks +import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks, fixed """ Sequence(building_blocks...; TR=nothing) @@ -36,6 +36,7 @@ struct Sequence <: ContainerBlock end Sequence(blocks...; TR=nothing) = Sequence([blocks...]; TR=TR) +fixed(seq::Sequence) = fixed.(seq._blocks, TR=fixed(seq.TR)) Base.length(seq::Sequence) = length(seq._blocks) Base.getindex(seq::Sequence, index) = seq._blocks[index] -- GitLab