From bd39a767a7f1001e67e01214ecbe5af4284bd3e4 Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk> Date: Thu, 1 Feb 2024 14:55:22 +0000 Subject: [PATCH] Set scanner at a global level --- src/build_sequences.jl | 17 ++++++++++++++--- src/overlapping/generic.jl | 6 ++++++ src/overlapping/trapezoid_gradients.jl | 2 ++ src/scanners.jl | 5 +++-- src/sequences.jl | 15 +++++---------- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/build_sequences.jl b/src/build_sequences.jl index df993dc..5182ce9 100644 --- a/src/build_sequences.jl +++ b/src/build_sequences.jl @@ -2,10 +2,14 @@ module BuildSequences import JuMP: Model, optimizer_with_attributes, optimize! import Ipopt import Juniper +import ..Scanners: Scanner, scanner_constraints! +import ..BuildingBlocks: BuildingBlock const GLOBAL_MODEL = Ref(Model()) const IGNORE_MODEL = GLOBAL_MODEL[] +const GLOBAL_SCANNER = Ref(Scanner()) + """ Wrapper to build a sequence. @@ -31,23 +35,28 @@ You can also add any arbitrary constraints or objectives using the same syntax a As soon as the code block is the optimal sequence matching all your constraints and objectives will be returned. """ -function build_sequence(f::Function, model::Model) +function build_sequence(f::Function, scanner::Scanner, model::Model) prev_model = GLOBAL_MODEL[] GLOBAL_MODEL[] = model + prev_scanner = GLOBAL_SCANNER[] + if !isnothing(scanner) + GLOBAL_SCANNER[] = scanner + end try sequence = f(model) optimize!(model) return sequence finally GLOBAL_MODEL[] = prev_model + GLOBAL_SCANNER[] = prev_scanner end end -function build_sequence(f::Function) +function build_sequence(f::Function, scanner::Scanner) ipopt_opt = optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 3) juniper_opt = optimizer_with_attributes(Juniper.Optimizer, "nl_solver" => ipopt_opt) model = Model(juniper_opt) - build_sequence(f, model) + build_sequence(f, scanner, model) end @@ -58,6 +67,8 @@ function get_global_model() return GLOBAL_MODEL[] end +scanner_constraints!(bb::BuildingBlock) = scanner_constraints!(bb, GLOBAL_SCANNER[]) + """ @global_model_constructor BuildingBlockType diff --git a/src/overlapping/generic.jl b/src/overlapping/generic.jl index 2874db5..433437f 100644 --- a/src/overlapping/generic.jl +++ b/src/overlapping/generic.jl @@ -5,6 +5,7 @@ import ...Wait: WaitBlock import ...Readouts: InstantReadout import ...Pulses: RFPulseBlock import ...Gradients: GradientBlock +import ...Scanners: scanner_constraints! """ GenericWaveform(duration, waveform, interruptions) @@ -20,6 +21,11 @@ struct GenericOverlapping <: AbstractOverlapping duration :: VariableType waveform :: Vector{Union{WaitBlock, GradientBlock}} interruptions :: Vector{NamedTuple{(:index, :time, :block), Tuple{Int64, <:VariableType, <:Union{RFPulseBlock, InstantReadout}}}} + function GenericOverlapping(duration::VariableType, waveform::AbstractVector, interruptions::AbstractVector=[]) + res = new(duration, waveform, interruptions) + scanner_constraints!.(waveform) + return res + end end diff --git a/src/overlapping/trapezoid_gradients.jl b/src/overlapping/trapezoid_gradients.jl index 39fb32e..f5c1f91 100644 --- a/src/overlapping/trapezoid_gradients.jl +++ b/src/overlapping/trapezoid_gradients.jl @@ -10,6 +10,7 @@ import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, v import ...BuildingBlocks: duration, set_simple_constraints!, fixed, RFPulseBlock import ...BuildSequences: @global_model_constructor import ...Gradients: ChangingGradientBlock, ConstantGradientBlock +import ...Scanners: scanner_constraints! import ..Abstract: interruptions, waveform, AbstractOverlapping @@ -112,6 +113,7 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing, set_simple_constraints!(model, res, kwargs) @constraint model flat_time >= 0 @constraint model rise_time >= 0 + scanner_constraints!(res) return res end diff --git a/src/scanners.jl b/src/scanners.jl index e707a0d..b5ddf9a 100644 --- a/src/scanners.jl +++ b/src/scanners.jl @@ -4,7 +4,8 @@ Define general [`Scanner`](@ref) type and methods as well as some concrete scann module Scanners import JuMP: Model, @constraint, owner_model import ..Variables: gradient_strength, slew_rate -import ..BuildingBlocks: BuildingBlock, get_children_blocks +import ..BuildingBlocks: BuildingBlock, get_children_blocks, ContainerBlock +import ..Variables: variables const gyromagnetic_ratio = 42576.38476 # (kHz/T) @@ -88,7 +89,7 @@ predefined_scanners = Dict( ) """ - scanner_constraints!([model, ]building_block, scanner) + scanner_constraints!([model, ]building_block[, scanner]) Adds any constraints from a specific scanner to a [`BuildingBlock`]{@ref}. """ diff --git a/src/sequences.jl b/src/sequences.jl index 37a7af1..e6f2947 100644 --- a/src/sequences.jl +++ b/src/sequences.jl @@ -6,16 +6,14 @@ import Printf: @sprintf import JuMP: Model, @constraint import ...BuildSequences: @global_model_constructor import ...Variables: variables, start_time, duration, VariableType, get_free_variable, TR, end_time -import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, scanner_constraints!, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks +import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks """ - Sequence(building_blocks...; TR=nothing, scanner=nothing) - Sequence([building_blocks]; TR=nothing, scanner=nothing) + Sequence(building_blocks...; TR=nothing) + Sequence([building_blocks]; TR=nothing) Represents a series of [`BuildingBlock`](@ref) objects run in turn. -Providing a [`Scanner`](@ref) will lead to [`scanner_constraints!`](@ref) to be called on all building blocks. - This can be used as a top-level NMR/MRI sequence (in which case the [`TR`](@ref) variable is relevant) or be embedded as a [`BuildingBlock`](@ref) into higher-order `Sequence` or other [`ContainerBlock`](@ref) objects. @@ -26,7 +24,7 @@ struct Sequence <: ContainerBlock model :: Model _blocks :: Vector{<:BuildingBlock} TR :: VariableType - function Sequence(model::Model, blocks::AbstractVector; TR=nothing, scanner=nothing) + function Sequence(model::Model, blocks::AbstractVector; TR=nothing) seq = new( model, to_block.(blocks), @@ -35,16 +33,13 @@ struct Sequence <: ContainerBlock if !(TR isa Number && isinf(TR)) @constraint model seq.TR >= duration(seq) end - if !isnothing(scanner) - scanner_constraints!(model, seq, scanner) - end return seq end end @global_model_constructor Sequence -Sequence(model::Model, blocks...; TR=nothing, scanner=nothing) = Sequence(model, [blocks...]; TR=TR, scanner=scanner) +Sequence(model::Model, blocks...; TR=nothing) = Sequence(model, [blocks...]; TR=TR) Base.length(seq::Sequence) = length(seq._blocks) Base.getindex(seq::Sequence, index) = seq._blocks[index] -- GitLab