From 6fb836020356d8ac831a163201e3b52b8ad8e72d Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk> Date: Fri, 2 Feb 2024 10:27:28 +0000 Subject: [PATCH] Replace model parameter with global_model --- src/MRIBuilder.jl | 8 ++-- src/build_sequences.jl | 29 +++++++-------- src/building_blocks.jl | 49 ++++++------------------- src/gradients/instant_gradients.jl | 11 ++---- src/helper_functions.jl | 6 +-- src/overlapping/spoilt_slice_selects.jl | 8 ++-- src/overlapping/trapezoid_gradients.jl | 32 +++++++--------- src/pulses/constant_pulses.jl | 16 +++----- src/pulses/instant_pulses.jl | 16 +++----- src/pulses/sinc_pulses.jl | 19 ++++------ src/scanners.jl | 10 ++--- src/sequences.jl | 16 +++----- src/variables.jl | 31 +++++++--------- 13 files changed, 95 insertions(+), 156 deletions(-) diff --git a/src/MRIBuilder.jl b/src/MRIBuilder.jl index c449a35..11be175 100644 --- a/src/MRIBuilder.jl +++ b/src/MRIBuilder.jl @@ -16,8 +16,8 @@ include("sequences.jl") include("pathways.jl") include("helper_functions.jl") -import .BuildSequences: build_sequence -export build_sequence +import .BuildSequences: build_sequence, global_model, global_scanner +export build_sequence, global_model, global_scanner import .Scanners: Scanner, B0, Siemens_Connectom, Siemens_Prisma, Siemens_Terra export Scanner, B0, Siemens_Connectom, Siemens_Prisma, Siemens_Terra @@ -52,7 +52,7 @@ export Pathway, duration_transverse, duration_dephase, bval, bmat import .HelperFunctions: excitation_pulse, refocus_pulse export excitation_pulse, refocus_pulse -import JuMP: @constraint, @objective, objective_function, optimize!, has_values, value, owner_model, Model -export @constraint, @objective, objective_function, optimize!, has_values, value, owner_model, Model +import JuMP: @constraint, @objective, objective_function, value, Model +export @constraint, @objective, objective_function, value, Model end diff --git a/src/build_sequences.jl b/src/build_sequences.jl index 70dfdd3..41deba6 100644 --- a/src/build_sequences.jl +++ b/src/build_sequences.jl @@ -3,7 +3,7 @@ import JuMP: Model, optimizer_with_attributes, optimize! import Ipopt import Juniper import ..Scanners: Scanner, scanner_constraints! -import ..BuildingBlocks: BuildingBlock +import ..BuildingBlocks: BuildingBlock, apply_simple_constraint!, match_blocks! const GLOBAL_MODEL = Ref(Model()) const IGNORE_MODEL = GLOBAL_MODEL[] @@ -60,29 +60,26 @@ function build_sequence(f::Function, scanner::Scanner) end -function get_global_model() +function global_model() if GLOBAL_MODEL[] == IGNORE_MODEL error("No global model has been set. Please explicitly set one in the constructor or set a global model using `set_model`.") end return GLOBAL_MODEL[] end -scanner_constraints!(bb::BuildingBlock) = scanner_constraints!(bb, GLOBAL_SCANNER[]) - - -""" - @global_model_constructor BuildingBlockType - -Add a constructor to the [`BuildingBlock`](@ref) subtype that fetches the global JuMP model (set by [`set_model`](@ref)) and assigns it to the first argument. -``` -BuildingBlockType(args...; kwargs...) = BuildingBlockType(global_model::JuMP.Model, args...; kwargs...) -``` -""" -macro global_model_constructor(bb) - quote - $(esc(bb))(args...; kwargs...) = $(esc(bb))(get_global_model(), args...; kwargs...) +function global_scanner() + if !isfinite(GLOBAL_SCANNER[].gradient) + error("No valid scanner has been set. Please provide one when calling `build_sequence`.") end + return GLOBAL_SCANNER[] end +scanner_constraints!(bb::BuildingBlock) = scanner_constraints!(bb, global_scanner()) + +apply_simple_constraint!(variable, value) = apply_simple_constraint!(global_model(), variable, value) +match_blocks!(block1::BuildingBlock, block2::BuildingBlock, property_list) = match_blocks!(global_model(), block1, block2, property_list) +scanner_constraints!(building_block::BuildingBlock, scanner::Scanner, func::Function) = scanner_constraints!(building_block, scanner, func) +scanner_constraints!(building_block::BuildingBlock) = scanner_constraints!(building_block, global_scanner()) + end \ No newline at end of file diff --git a/src/building_blocks.jl b/src/building_blocks.jl index 48588f5..3c278dd 100644 --- a/src/building_blocks.jl +++ b/src/building_blocks.jl @@ -1,5 +1,5 @@ module BuildingBlocks -import JuMP: has_values, value, Model, @constraint, @objective, owner_model, objective_function, optimize!, AbstractJuMPScalar +import JuMP: value, Model, @constraint, @objective, objective_function, AbstractJuMPScalar import Printf: @sprintf import ..Variables: variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, qval_square @@ -167,7 +167,7 @@ function Base.show(io::IO, printer::BuildingBlockPrinter) for name in propertynames(block) ft = fieldtype(typeof(block), name) if ( - ft in (VariableType, Model) || + ft == VariableType || (ft <: AbstractVector && eltype(ft) == VariableType) || string(name)[1] == '_' ) @@ -197,15 +197,15 @@ end """ - set_simple_constraints!(model, block, kwargs) + set_simple_constraints!(block, kwargs) -Add any constraints or objective functions to the variables of a [`BuildingBlock`](@ref) in the JuMP `model`. +Add any constraints or objective functions to the variables of a [`BuildingBlock`](@ref). Each keyword argument has to match one of the functions in [`variables`](@ref)(block). If set to a numeric value, a constraint will be added to fix the function value to that numeric value. If set to `:min` or `:max`, minimising or maximising this function will be added to the cost function. """ -function set_simple_constraints!(model::Model, block::BuildingBlock, kwargs) +function set_simple_constraints!(block::BuildingBlock, kwargs) to_funcs = Dict(nameof(fn) => fn for fn in variables(block)) invert_value(value::VariableType) = 1 / value @@ -217,14 +217,14 @@ function set_simple_constraints!(model::Model, block::BuildingBlock, kwargs) for (key, value) in kwargs if key in keys(to_funcs) - apply_simple_constraint!(model, to_funcs[key](block), value) + apply_simple_constraint!(to_funcs[key](block), value) else if key == :qval - apply_simple_constraint!(model, to_funcs[:qval_square](block), value isa VariableType ? value^2 : value) + apply_simple_constraint!(to_funcs[:qval_square](block), value isa VariableType ? value^2 : value) elseif key == :slice_thickness && :inverse_slice_thickness in keys(to_funcs) - apply_simple_constraint!(model, to_funcs[:inverse_slice_thickness](block), invert_value(value)) + apply_simple_constraint!(to_funcs[:inverse_slice_thickness](block), invert_value(value)) elseif key == :bandwidth && :inverse_bandwidth in keys(to_funcs) - apply_simple_constraint!(model, to_funcs[:inverse_bandwidth](block), invert_value(value)) + apply_simple_constraint!(to_funcs[:inverse_bandwidth](block), invert_value(value)) else error("Trying to set an unrecognised variable $key.") end @@ -234,7 +234,7 @@ function set_simple_constraints!(model::Model, block::BuildingBlock, kwargs) end """ - apply_simple_constraint!(model, variable, value) + apply_simple_constraint!([model, ]variable, value) Add a single constraint or objective to the JuMP `model`. This is an internal function used by [`set_simple_constraints`](@ref). @@ -253,9 +253,7 @@ apply_simple_constraint!(model::Model, variable::AbstractVector, value::Abstract Matches the listed variables between two [`BuildingBlock`](@ref) objects. By default all shared variables (i.e., those with the same name) are matched. """ -function match_blocks!(block1::BuildingBlock, block2::BuildingBlock, property_list) - model = owner_model(block1) - @assert model == owner_model(block2) +function match_blocks!(model::Model, block1::BuildingBlock, block2::BuildingBlock, property_list) for fn in property_list @constraint model fn(block1) == fn(block2) end @@ -266,29 +264,4 @@ function match_blocks!(block1::BuildingBlock, block2::BuildingBlock) match_blocks!(block1, block2, property_list) end - -optimize!(bb::BuildingBlock) = optimize!(owner_model(bb)) -function owner_model(bb::BuildingBlock) - if hasproperty(bb, :model) - return bb.model - else - for name in propertynames(bb) - value = getproperty(bb, name) - if value isa AbstractJuMPScalar - return owner_model(value) - end - end - end - error("Cannot find owner model") -end - -function has_values(bb::BuildingBlock) - try - return has_values(owner_model(bb)) - catch - # return true for building blocks without a model - return true - end -end - end \ No newline at end of file diff --git a/src/gradients/instant_gradients.jl b/src/gradients/instant_gradients.jl index 1586625..8658b38 100644 --- a/src/gradients/instant_gradients.jl +++ b/src/gradients/instant_gradients.jl @@ -1,9 +1,9 @@ module InstantGradients import StaticArrays: SVector -import JuMP: @constraint, @variable, Model, owner_model, AbstractJuMPScalar +import JuMP: @constraint, @variable, AbstractJuMPScalar import ...Variables: qvec, bmat_gradient, duration, variables, get_free_variable, VariableType import ...BuildingBlocks: GradientBlock, fixed -import ...BuildSequences: @global_model_constructor +import ...BuildSequences: global_model """ InstantGradientBlock(; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing) @@ -20,14 +20,11 @@ Defines an instantaneous gradient. - [`qval`](@ref): Spatial scale on which spins will be dephased due to this pulsed gradient in rad/um. """ struct InstantGradientBlock <: GradientBlock - model::Model qvec :: SVector{3, VariableType} rotate :: Union{Nothing, Symbol} scale :: Union{Nothing, Symbol} end -@global_model_constructor InstantGradientBlock - function interpret_orientation(orientation, qval, qvec, will_rotate) if !isnothing(qvec) return (false, qvec) @@ -60,10 +57,10 @@ function interpret_orientation(orientation, qval, qvec, will_rotate) end end -function InstantGradientBlock(model::Model; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing) +function InstantGradientBlock(; orientation=nothing, qval=nothing, qvec=nothing, rotate=nothing, scale=nothing) + model = global_model() (used_qval, qvec) = interpret_orientation(orientation, qval, qvec, !isnothing(rotate)) res = InstantGradientBlock( - model, qvec, rotate, scale diff --git a/src/helper_functions.jl b/src/helper_functions.jl index a170638..9be264f 100644 --- a/src/helper_functions.jl +++ b/src/helper_functions.jl @@ -1,6 +1,6 @@ module HelperFunctions import JuMP: @constraint -import ..BuildSequences: get_global_model +import ..BuildSequences: global_model import ..Sequences: Sequence import ..Pulses: SincPulse, ConstantPulse, InstantRFPulseBlock import ..Overlapping: TrapezoidGradient, SpoiltSliceSelect @@ -65,7 +65,7 @@ function excitation_pulse(; flip_angle=90, phase=0., frequency=0., shape=:sinc, TrapezoidGradient(orientation=[0, 0, -1.], rotate=rotate_grad, duration=:min); TR=Inf ) - @constraint get_global_model() qvec(grad, 1, nothing)[3] == -qvec(seq[2])[3] + @constraint global_model() qvec(grad, 1, nothing)[3] == -qvec(seq[2])[3] return seq end @@ -112,7 +112,7 @@ function refocus_pulse(; flip_angle=180, phase=0., frequency=0., shape=:sinc, sl orientation=isnothing(rotate_grad) ? [1, 1, 1] : [0, 0, 1] if isinf(slice_thickness) grad = TrapezoidGradient(orientation=orientation, duration=:min, rotate=rotate_grad) - @constraint get_global_model() qvec(grad)[3] == 2π * 1e-3 / spoiler + @constraint global_model() qvec(grad)[3] == 2π * 1e-3 / spoiler return Sequence(grad, pulse, grad; TR=Inf) else res = SpoiltSliceSelect(pulse; orientation=orientation, duration=:min, rotate=rotate_grad, slice_thickness=slice_thickness, spoiler_scale=spoiler) diff --git a/src/overlapping/spoilt_slice_selects.jl b/src/overlapping/spoilt_slice_selects.jl index 1cbe4a5..e218625 100644 --- a/src/overlapping/spoilt_slice_selects.jl +++ b/src/overlapping/spoilt_slice_selects.jl @@ -4,7 +4,7 @@ import LinearAlgebra: norm import StaticArrays: SVector import JuMP: @variable, @constraint, @objective, objective_function import ...BuildingBlocks: RFPulseBlock, set_simple_constraints! -import ...BuildSequences: GLOBAL_SCANNER, get_global_model +import ...BuildSequences: global_model, global_scanner import ...Variables: VariableType, variables, duration, rise_time, flat_time, effective_time, qvec, gradient_strength, slew_rate import ...Gradients: ChangingGradientBlock, ConstantGradientBlock import ..Abstract: interruptions, waveform, AbstractOverlapping @@ -37,7 +37,7 @@ struct SpoiltSliceSelect <: AbstractOverlapping end function SpoiltSliceSelect(pulse; orientation=[0, 0, 1], rotate=nothing, spoiler_scale=nothing, kwargs...) - model = get_global_model() + model = global_model() res = SpoiltSliceSelect( orientation ./ norm(orientation), @variable(model, start=0.1), @@ -47,7 +47,7 @@ function SpoiltSliceSelect(pulse; orientation=[0, 0, 1], rotate=nothing, spoiler @variable(model, start=0.1), @variable(model, start=0.1), rotate, - slew_rate(GLOBAL_SCANNER[]) + slew_rate(global_scanner()) ) for time_var in (res.rise_time1, res.flat_time1, res.diff_time, res.flat_time2, res.fall_time2) @constraint model time_var >= 0 @@ -68,7 +68,7 @@ function SpoiltSliceSelect(pulse; orientation=[0, 0, 1], rotate=nothing, spoiler end set_simple_constraints!(model, res, kwargs) - max_time = gradient_strength(GLOBAL_SCANNER[]) / slew_rate(GLOBAL_SCANNER[]) + max_time = gradient_strength(global_scanner()) / res.slew_rate @constraint model rise_time(res)[1] <= max_time @constraint model fall_time(res)[2] <= max_time return res diff --git a/src/overlapping/trapezoid_gradients.jl b/src/overlapping/trapezoid_gradients.jl index b199e49..5f86100 100644 --- a/src/overlapping/trapezoid_gradients.jl +++ b/src/overlapping/trapezoid_gradients.jl @@ -3,12 +3,12 @@ Defines a set of different options for MRI gradients. """ module TrapezoidGradients -import JuMP: @constraint, @variable, Model, VariableRef, owner_model, value +import JuMP: @constraint, @variable, VariableRef, value import StaticArrays: SVector import LinearAlgebra: norm import ...Variables: qvec, rise_time, flat_time, slew_rate, gradient_strength, variables, duration, δ, get_free_variable, VariableType, inverse_slice_thickness, inverse_bandwidth, effective_time import ...BuildingBlocks: duration, set_simple_constraints!, fixed, RFPulseBlock -import ...BuildSequences: @global_model_constructor +import ...BuildSequences: global_model import ...Gradients: ChangingGradientBlock, ConstantGradientBlock import ...Scanners: scanner_constraints! import ..Abstract: interruptions, waveform, AbstractOverlapping @@ -43,7 +43,6 @@ Any variables defined for the specific pulse added. Also: The [`bvalue`](@ref) can be constrained for multiple gradient pulses. """ struct TrapezoidGradient <: AbstractOverlapping - model :: Model rise_time :: VariableType flat_time :: VariableType slew_rate :: SVector{3, VariableType} @@ -54,21 +53,19 @@ struct TrapezoidGradient <: AbstractOverlapping time_after_pulse :: VariableType end -@global_model_constructor TrapezoidGradient - -function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing, flat_time=nothing, rotate=nothing, scale=nothing, pulse=nothing,kwargs...) +function TrapezoidGradient(orientation=nothing, rise_time=nothing, flat_time=nothing, rotate=nothing, scale=nothing, pulse=nothing,kwargs...) if isnothing(orientation) && isnothing(rotate) rate_1d = nothing slew_rate = ( - get_free_variable(model, nothing), - get_free_variable(model, nothing), - get_free_variable(model, nothing), + get_free_variable(nothing), + get_free_variable(nothing), + get_free_variable(nothing), ) else - rate_1d = get_free_variable(model, nothing) - @constraint model rate_1d >= 0 + rate_1d = get_free_variable(nothing) + @constraint global_model() rate_1d >= 0 if isnothing(orientation) - rate_1d = get_free_variable(model, nothing) + rate_1d = get_free_variable(nothing) slew_rate = ( rate_1d, 0., @@ -86,9 +83,9 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing, end end slew_rate = SVector{3}(slew_rate) - rise_time = get_free_variable(model, rise_time) + rise_time = get_free_variable(rise_time) if isnothing(pulse) - flat_time = get_free_variable(model, flat_time) + flat_time = get_free_variable(flat_time) time_before_pulse = time_after_pulse = 0. elseif pulse isa RFPulseBlock flat_time = duration(pulse) @@ -99,7 +96,6 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing, end res = TrapezoidGradient( - model, rise_time, flat_time, slew_rate, @@ -110,9 +106,9 @@ function TrapezoidGradient(model::Model; orientation=nothing, rise_time=nothing, time_after_pulse ) - set_simple_constraints!(model, res, kwargs) - @constraint model flat_time >= 0 - @constraint model rise_time >= 0 + set_simple_constraints!(res, kwargs) + @constraint flat_time >= 0 + @constraint rise_time >= 0 scanner_constraints!(res) return res end diff --git a/src/pulses/constant_pulses.jl b/src/pulses/constant_pulses.jl index 23e73b9..cf069cb 100644 --- a/src/pulses/constant_pulses.jl +++ b/src/pulses/constant_pulses.jl @@ -1,8 +1,8 @@ module ConstantPulses -import JuMP: VariableRef, @constraint, @variable, value, Model +import JuMP: VariableRef, @constraint, @variable, value import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed import ...Variables: variables, get_free_variable, flip_angle, phase, amplitude, frequency, start_time, end_time, VariableType, duration, effective_time, inverse_bandwidth -import ...BuildSequences: @global_model_constructor +import ...BuildSequences: global_model import ..FixedPulses: FixedPulse """ @@ -18,7 +18,6 @@ Represents an radio-frequency pulse with a constant amplitude and frequency (i.e - [`frequency`](@ref): frequency of the RF pulse relative to the Larmor frequency (in kHz). """ struct ConstantPulse <: RFPulseBlock - model :: Model amplitude :: VariableType duration :: VariableType phase :: VariableType @@ -26,17 +25,14 @@ struct ConstantPulse <: RFPulseBlock scale :: Union{Nothing, Symbol} end -@global_model_constructor ConstantPulse - -function ConstantPulse(model::Model; amplitude=nothing, duration=nothing, phase=nothing, frequency=nothing, scale=nothing, kwargs...) +function ConstantPulse(amplitude=nothing, duration=nothing, phase=nothing, frequency=nothing, scale=nothing, kwargs...) res = ConstantPulse( - model, - [get_free_variable(model, value) for value in (amplitude, duration, phase, frequency)]..., + [get_free_variable(value) for value in (amplitude, duration, phase, frequency)]..., scale ) - @constraint model res.amplitude >= 0 - set_simple_constraints!(model, res, kwargs) + @constraint global_model() res.amplitude >= 0 + set_simple_constraints!(res, kwargs) return res end diff --git a/src/pulses/instant_pulses.jl b/src/pulses/instant_pulses.jl index fb49d0e..7234ce0 100644 --- a/src/pulses/instant_pulses.jl +++ b/src/pulses/instant_pulses.jl @@ -1,27 +1,23 @@ module InstantPulses -import JuMP: @constraint, @variable, VariableRef, value, Model +import JuMP: @constraint, @variable, VariableRef, value import ...BuildingBlocks: RFPulseBlock, fixed import ...Variables: flip_angle, phase, start_time, variables, duration, get_free_variable, VariableType, effective_time, inverse_bandwidth -import ...BuildSequences: @global_model_constructor +import ...BuildSequences: global_model import ..FixedPulses: FixedInstantPulse struct InstantRFPulseBlock <: RFPulseBlock - model :: Model flip_angle :: VariableType phase :: VariableType scale :: Union{Nothing, Symbol} end -@global_model_constructor InstantRFPulseBlock - -function InstantRFPulseBlock(model::Model; flip_angle=nothing, phase=nothing, scale=nothing) +function InstantRFPulseBlock(flip_angle=nothing, phase=nothing, scale=nothing) res = InstantRFPulseBlock( - model, - get_free_variable(model, flip_angle), - get_free_variable(model, phase), + get_free_variable(flip_angle), + get_free_variable(phase), scale ) - @constraint model res.flip_angle >= 0 + @constraint global_model() res.flip_angle >= 0 return res end diff --git a/src/pulses/sinc_pulses.jl b/src/pulses/sinc_pulses.jl index ecdcfe8..35db4b2 100644 --- a/src/pulses/sinc_pulses.jl +++ b/src/pulses/sinc_pulses.jl @@ -1,11 +1,11 @@ module SincPulses -import JuMP: VariableRef, @constraint, @variable, value, Model +import JuMP: VariableRef, @constraint, @variable, value import QuadGK: quadgk import Polynomials: fit, Polynomial import ...BuildingBlocks: RFPulseBlock, set_simple_constraints!, fixed import ...Variables: flip_angle, phase, amplitude, frequency, VariableType, variables, get_free_variable, duration, effective_time, inverse_bandwidth -import ...BuildSequences: @global_model_constructor +import ...BuildSequences: global_model import ..FixedPulses: FixedPulse """ @@ -29,7 +29,6 @@ Represents an radio-frequency pulse with a constant amplitude and frequency. - [`bandwidth`](@ref): width of the rectangular function in frequency space (in kHz). If the `duration` is short (compared with 1/`bandwidth`), this bandwidth will only be approximate. """ struct SincPulse <: RFPulseBlock - model :: Model symmetric :: Bool apodise :: Bool nlobe_integral :: Polynomial @@ -42,31 +41,29 @@ struct SincPulse <: RFPulseBlock scale :: Union{Nothing, Symbol} end -@global_model_constructor SincPulse - -function SincPulse(model::Model; +function SincPulse(; symmetric=true, max_Nlobes=nothing, apodise=true, N_lobes=nothing, N_left=nothing, N_right=nothing, amplitude=nothing, phase=nothing, frequency=nothing, lobe_duration=nothing, scale=nothing, kwargs... ) if symmetric - N_lobes = get_free_variable(model, N_lobes) + N_lobes = get_free_variable(N_lobes) @assert isnothing(N_left) && isnothing(N_right) "N_left and N_right cannot be set if symmetric=true (default)" N_left_var = N_right_var = N_lobes else @assert isnothing(N_lobes) "N_lobes cannot be set if symmetric=true (default)" - N_left_var = get_free_variable(model, N_left) - N_right_var = get_free_variable(model, N_right) + N_left_var = get_free_variable(N_left) + N_right_var = get_free_variable(N_right) end res = SincPulse( - model, symmetric, apodise, nlobe_integral_params(max_Nlobes, apodise), N_left_var, N_right_var, - [get_free_variable(model, value) for value in (amplitude, phase, frequency, lobe_duration)]..., + [get_free_variable(value) for value in (amplitude, phase, frequency, lobe_duration)]..., scale ) + model = global_model() @constraint model res.amplitude >= 0 @constraint model res.N_left >= 1 if !symmetric diff --git a/src/scanners.jl b/src/scanners.jl index b5ddf9a..6c026a8 100644 --- a/src/scanners.jl +++ b/src/scanners.jl @@ -2,7 +2,7 @@ Define general [`Scanner`](@ref) type and methods as well as some concrete scanners. """ module Scanners -import JuMP: Model, @constraint, owner_model +import JuMP: Model, @constraint import ..Variables: gradient_strength, slew_rate import ..BuildingBlocks: BuildingBlock, get_children_blocks, ContainerBlock import ..Variables: variables @@ -89,18 +89,14 @@ predefined_scanners = Dict( ) """ - scanner_constraints!([model, ]building_block[, scanner]) + scanner_constraints!(building_block[, scanner]) Adds any constraints from a specific scanner to a [`BuildingBlock`]{@ref}. """ function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner) - scanner_constraints!(owner_model(building_block), building_block, scanner) -end - -function scanner_constraints!(model::Model, building_block::BuildingBlock, scanner::Scanner) for func in [gradient_strength, slew_rate] if isfinite(func(scanner)) - scanner_constraints!(model, building_block, scanner, func) + scanner_constraints!(building_block, scanner, func) end end end diff --git a/src/sequences.jl b/src/sequences.jl index e6f2947..bd5910c 100644 --- a/src/sequences.jl +++ b/src/sequences.jl @@ -3,8 +3,8 @@ Define the [`Sequence`](@ref) building block. """ module Sequences import Printf: @sprintf -import JuMP: Model, @constraint -import ...BuildSequences: @global_model_constructor +import JuMP: @constraint +import ...BuildSequences: global_model import ...Variables: variables, start_time, duration, VariableType, get_free_variable, TR, end_time import ...BuildingBlocks: BuildingBlock, ContainerBlock, to_block, get_children_indices, fixed, BuildingBlockPrinter, _robust_value, get_children_blocks @@ -21,25 +21,21 @@ or be embedded as a [`BuildingBlock`](@ref) into higher-order `Sequence` or othe - [`TR`](@ref): repetition time of sequence in ms. """ struct Sequence <: ContainerBlock - model :: Model _blocks :: Vector{<:BuildingBlock} TR :: VariableType - function Sequence(model::Model, blocks::AbstractVector; TR=nothing) + function Sequence(blocks::AbstractVector; TR=nothing) seq = new( - model, to_block.(blocks), - get_free_variable(model, TR), + get_free_variable(TR), ) if !(TR isa Number && isinf(TR)) - @constraint model seq.TR >= duration(seq) + @constraint global_model() seq.TR >= duration(seq) end return seq end end -@global_model_constructor Sequence - -Sequence(model::Model, blocks...; TR=nothing) = Sequence(model, [blocks...]; TR=TR) +Sequence(blocks...; TR=nothing) = Sequence([blocks...]; TR=TR) Base.length(seq::Sequence) = length(seq._blocks) Base.getindex(seq::Sequence, index) = seq._blocks[index] diff --git a/src/variables.jl b/src/variables.jl index 01afc14..d1b7f0d 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -1,5 +1,6 @@ module Variables -import JuMP: @variable, Model, @objective, objective_function, owner_model, has_values, value, AbstractJuMPScalar +import JuMP: @variable, Model, @objective, objective_function, value, AbstractJuMPScalar +import .. all_variables_symbols = [ :block => [ @@ -82,29 +83,23 @@ const VariableType = Union{Number, AbstractJuMPScalar} """ - get_free_variable(model, value; integer=false) + get_free_variable(value; integer=false) Get a representation of a given `variable` given a user-defined constraint. """ -get_free_variable(::Model, value::Number; integer=false) = integer ? Int(value) : Float64(value) -function get_free_variable(model::Model, value::VariableType; integer=false) - if owner_model(value) != model - if has_values(value) - return value(value) - end - error("Cannot set any constraints between sequences stored in different JuMP models.") - end - return value -end -get_free_variable(model::Model, ::Nothing; integer=false) = @variable(model, start=0.01, integer=integer) -get_free_variable(model::Model, value::Symbol; integer=false) = integer ? error("Cannot maximise or minimise an integer variable") : get_free_variable(model, Val(value)) -function get_free_variable(model::Model, ::Val{:min}) - var = get_free_variable(model, nothing) +get_free_variable(value::Number; integer=false) = integer ? Int(value) : Float64(value) +get_free_variable(value::VariableType; integer=false) = value +get_free_variable(::Nothing; integer=false) = @variable(global_model(), start=0.01, integer=integer) +get_free_variable(value::Symbol; integer=false) = integer ? error("Cannot maximise or minimise an integer variable") : get_free_variable(Val(value)) +function get_free_variable(::Val{:min}) + var = get_free_variable(nothing) + model = global_model() @objective model Min objective_function(model) + var return var end -function get_free_variable(model::Model, ::Val{:max}) - var = get_free_variable(model, nothing) +function get_free_variable(::Val{:max}) + var = get_free_variable(nothing) + model = global_model() @objective model Min objective_function(model) - var return var end -- GitLab