Skip to content
Snippets Groups Projects
Verified Commit bd39a767 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

Set scanner at a global level

parent 97991aaa
No related branches found
No related tags found
No related merge requests found
...@@ -2,10 +2,14 @@ module BuildSequences ...@@ -2,10 +2,14 @@ module BuildSequences
import JuMP: Model, optimizer_with_attributes, optimize! import JuMP: Model, optimizer_with_attributes, optimize!
import Ipopt import Ipopt
import Juniper import Juniper
import ..Scanners: Scanner, scanner_constraints!
import ..BuildingBlocks: BuildingBlock
const GLOBAL_MODEL = Ref(Model()) const GLOBAL_MODEL = Ref(Model())
const IGNORE_MODEL = GLOBAL_MODEL[] const IGNORE_MODEL = GLOBAL_MODEL[]
const GLOBAL_SCANNER = Ref(Scanner())
""" """
Wrapper to build a sequence. Wrapper to build a sequence.
...@@ -31,23 +35,28 @@ You can also add any arbitrary constraints or objectives using the same syntax a ...@@ -31,23 +35,28 @@ You can also add any arbitrary constraints or objectives using the same syntax a
As soon as the code block is the optimal sequence matching all your constraints and objectives will be returned. As soon as the code block is the optimal sequence matching all your constraints and objectives will be returned.
""" """
function build_sequence(f::Function, model::Model) function build_sequence(f::Function, scanner::Scanner, model::Model)
prev_model = GLOBAL_MODEL[] prev_model = GLOBAL_MODEL[]
GLOBAL_MODEL[] = model GLOBAL_MODEL[] = model
prev_scanner = GLOBAL_SCANNER[]
if !isnothing(scanner)
GLOBAL_SCANNER[] = scanner
end
try try
sequence = f(model) sequence = f(model)
optimize!(model) optimize!(model)
return sequence return sequence
finally finally
GLOBAL_MODEL[] = prev_model GLOBAL_MODEL[] = prev_model
GLOBAL_SCANNER[] = prev_scanner
end end
end end
function build_sequence(f::Function) function build_sequence(f::Function, scanner::Scanner)
ipopt_opt = optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 3) ipopt_opt = optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 3)
juniper_opt = optimizer_with_attributes(Juniper.Optimizer, "nl_solver" => ipopt_opt) juniper_opt = optimizer_with_attributes(Juniper.Optimizer, "nl_solver" => ipopt_opt)
model = Model(juniper_opt) model = Model(juniper_opt)
build_sequence(f, model) build_sequence(f, scanner, model)
end end
...@@ -58,6 +67,8 @@ function get_global_model() ...@@ -58,6 +67,8 @@ function get_global_model()
return GLOBAL_MODEL[] return GLOBAL_MODEL[]
end end
scanner_constraints!(bb::BuildingBlock) = scanner_constraints!(bb, GLOBAL_SCANNER[])
""" """
@global_model_constructor BuildingBlockType @global_model_constructor BuildingBlockType
......
...@@ -5,6 +5,7 @@ import ...Wait: WaitBlock ...@@ -5,6 +5,7 @@ import ...Wait: WaitBlock
import ...Readouts: InstantReadout import ...Readouts: InstantReadout
import ...Pulses: RFPulseBlock import ...Pulses: RFPulseBlock
import ...Gradients: GradientBlock import ...Gradients: GradientBlock
import ...Scanners: scanner_constraints!
""" """
GenericWaveform(duration, waveform, interruptions) GenericWaveform(duration, waveform, interruptions)
...@@ -20,6 +21,11 @@ struct GenericOverlapping <: AbstractOverlapping ...@@ -20,6 +21,11 @@ struct GenericOverlapping <: AbstractOverlapping
duration :: VariableType duration :: VariableType
waveform :: Vector{Union{WaitBlock, GradientBlock}} waveform :: Vector{Union{WaitBlock, GradientBlock}}
interruptions :: Vector{NamedTuple{(:index, :time, :block), Tuple{Int64, <:VariableType, <:Union{RFPulseBlock, InstantReadout}}}} interruptions :: Vector{NamedTuple{(:index, :time, :block), Tuple{Int64, <:VariableType, <:Union{RFPulseBlock, InstantReadout}}}}
function GenericOverlapping(duration::VariableType, waveform::AbstractVector, interruptions::AbstractVector=[])
res = new(duration, waveform, interruptions)
scanner_constraints!.(waveform)
return res
end
end end
......
...@@ -10,6 +10,7 @@ import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, v ...@@ -10,6 +10,7 @@ import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, v
import ...BuildingBlocks: duration, set_simple_constraints!, fixed, RFPulseBlock import ...BuildingBlocks: duration, set_simple_constraints!, fixed, RFPulseBlock
import ...BuildSequences: @global_model_constructor import ...BuildSequences: @global_model_constructor
import ...Gradients: ChangingGradientBlock, ConstantGradientBlock import ...Gradients: ChangingGradientBlock, ConstantGradientBlock
import ...Scanners: scanner_constraints!
import ..Abstract: interruptions, waveform, AbstractOverlapping import ..Abstract: interruptions, waveform, AbstractOverlapping
...@@ -112,6 +113,7 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing, ...@@ -112,6 +113,7 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing,
set_simple_constraints!(model, res, kwargs) set_simple_constraints!(model, res, kwargs)
@constraint model flat_time >= 0 @constraint model flat_time >= 0
@constraint model rise_time >= 0 @constraint model rise_time >= 0
scanner_constraints!(res)
return res return res
end end
......
...@@ -4,7 +4,8 @@ Define general [`Scanner`](@ref) type and methods as well as some concrete scann ...@@ -4,7 +4,8 @@ Define general [`Scanner`](@ref) type and methods as well as some concrete scann
module Scanners module Scanners
import JuMP: Model, @constraint, owner_model import JuMP: Model, @constraint, owner_model
import ..Variables: gradient_strength, slew_rate import ..Variables: gradient_strength, slew_rate
import ..BuildingBlocks: BuildingBlock, get_children_blocks import ..BuildingBlocks: BuildingBlock, get_children_blocks, ContainerBlock
import ..Variables: variables
const gyromagnetic_ratio = 42576.38476 # (kHz/T) const gyromagnetic_ratio = 42576.38476 # (kHz/T)
...@@ -88,7 +89,7 @@ predefined_scanners = Dict( ...@@ -88,7 +89,7 @@ predefined_scanners = Dict(
) )
""" """
scanner_constraints!([model, ]building_block, scanner) scanner_constraints!([model, ]building_block[, scanner])
Adds any constraints from a specific scanner to a [`BuildingBlock`]{@ref}. Adds any constraints from a specific scanner to a [`BuildingBlock`]{@ref}.
""" """
......
...@@ -6,16 +6,14 @@ import Printf: @sprintf ...@@ -6,16 +6,14 @@ import Printf: @sprintf
import JuMP: Model, @constraint import JuMP: Model, @constraint
import ...BuildSequences: @global_model_constructor import ...BuildSequences: @global_model_constructor
import ...Variables: variables, start_time, duration, VariableType, get_free_variable, TR, end_time import ...Variables: variables, start_time, duration, VariableType, get_free_variable, TR, end_time
import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, scanner_constraints!, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks
""" """
Sequence(building_blocks...; TR=nothing, scanner=nothing) Sequence(building_blocks...; TR=nothing)
Sequence([building_blocks]; TR=nothing, scanner=nothing) Sequence([building_blocks]; TR=nothing)
Represents a series of [`BuildingBlock`](@ref) objects run in turn. Represents a series of [`BuildingBlock`](@ref) objects run in turn.
Providing a [`Scanner`](@ref) will lead to [`scanner_constraints!`](@ref) to be called on all building blocks.
This can be used as a top-level NMR/MRI sequence (in which case the [`TR`](@ref) variable is relevant) This can be used as a top-level NMR/MRI sequence (in which case the [`TR`](@ref) variable is relevant)
or be embedded as a [`BuildingBlock`](@ref) into higher-order `Sequence` or other [`ContainerBlock`](@ref) objects. or be embedded as a [`BuildingBlock`](@ref) into higher-order `Sequence` or other [`ContainerBlock`](@ref) objects.
...@@ -26,7 +24,7 @@ struct Sequence <: ContainerBlock ...@@ -26,7 +24,7 @@ struct Sequence <: ContainerBlock
model :: Model model :: Model
_blocks :: Vector{<:BuildingBlock} _blocks :: Vector{<:BuildingBlock}
TR :: VariableType TR :: VariableType
function Sequence(model::Model, blocks::AbstractVector; TR=nothing, scanner=nothing) function Sequence(model::Model, blocks::AbstractVector; TR=nothing)
seq = new( seq = new(
model, model,
to_block.(blocks), to_block.(blocks),
...@@ -35,16 +33,13 @@ struct Sequence <: ContainerBlock ...@@ -35,16 +33,13 @@ struct Sequence <: ContainerBlock
if !(TR isa Number && isinf(TR)) if !(TR isa Number && isinf(TR))
@constraint model seq.TR >= duration(seq) @constraint model seq.TR >= duration(seq)
end end
if !isnothing(scanner)
scanner_constraints!(model, seq, scanner)
end
return seq return seq
end end
end end
@global_model_constructor Sequence @global_model_constructor Sequence
Sequence(model::Model, blocks...; TR=nothing, scanner=nothing) = Sequence(model, [blocks...]; TR=TR, scanner=scanner) Sequence(model::Model, blocks...; TR=nothing) = Sequence(model, [blocks...]; TR=TR)
Base.length(seq::Sequence) = length(seq._blocks) Base.length(seq::Sequence) = length(seq._blocks)
Base.getindex(seq::Sequence, index) = seq._blocks[index] Base.getindex(seq::Sequence, index) = seq._blocks[index]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment