diff --git a/src/components/gradient_waveforms/no_gradient_blocks.jl b/src/components/gradient_waveforms/no_gradient_blocks.jl index d6b5dae24c68b676dad7e3f6b12a9330d445605b..7b25feb159dde2b6d6f2f6c6ba05f756cc9351b6 100644 --- a/src/components/gradient_waveforms/no_gradient_blocks.jl +++ b/src/components/gradient_waveforms/no_gradient_blocks.jl @@ -16,7 +16,7 @@ struct NoGradient <: GradientWaveform{0} end for func in (:qvec, :gradient_strength, :slew_rate) - @eval variables.$(func).f(::NoGradient) = zero(SVector{3, Float64}) + @eval variables.$(func)(::NoGradient) = zero(SVector{3, Float64}) end @defvar begin diff --git a/src/components/pulses/composite_pulses.jl b/src/components/pulses/composite_pulses.jl index f32c105093e4939eec62b54c21af66af5f690f7e..87cca1813d2aae597e852576f6093ad809b49480 100644 --- a/src/components/pulses/composite_pulses.jl +++ b/src/components/pulses/composite_pulses.jl @@ -87,7 +87,7 @@ for (fn, default) in [ (:phase, NaN), (:frequency, NaN), ] - @eval function variables.$fn.f(pulse::CompositePulse, time::Number) + @eval function variables.$fn(pulse::CompositePulse, time::Number) (index, rtime) = get_pulse_index(pulse, time) if ( index < 1 || index > length(pulse) || diff --git a/src/components/pulses/generic_pulses.jl b/src/components/pulses/generic_pulses.jl index af88e4f7ca5dde134abe4957d915698443be5138..f75b17d7ebf1e46a774881575e754d446d5cda00 100644 --- a/src/components/pulses/generic_pulses.jl +++ b/src/components/pulses/generic_pulses.jl @@ -83,7 +83,7 @@ GenericPulse(pulse::RFPulseComponent, t1::Number, t2::Number) = GenericPulse(mak end for fn in (:amplitude, :phase) - @eval function variables.$fn.f(fp::GenericPulse, time::Number) + @eval function variables.$fn(fp::GenericPulse, time::Number) i2 = findfirst(t->t > time, fp.time) if isnothing(i2) @assert time ≈ fp.time[end] diff --git a/src/containers/base_sequences.jl b/src/containers/base_sequences.jl index 47d9f81dd0c0a3f9aee782ad9dbbce9f896451b1..91945097eee2967d9808966067aaf0cf7dcf34f9 100644 --- a/src/containers/base_sequences.jl +++ b/src/containers/base_sequences.jl @@ -122,9 +122,9 @@ function edge_times(seq::BaseSequence; tol=1e-6) end for fn in (:gradient_strength, :amplitude, :phase, :frequency) - @eval function variables.$fn.f(sequence::BaseSequence, time::AbstractFloat) + @eval function variables.$fn(sequence::BaseSequence, time::AbstractFloat) (block_time, block) = sequence(time) - return variables.$fn.f(block, block_time) + return variables.$fn(block, block_time) end end diff --git a/src/containers/building_blocks.jl b/src/containers/building_blocks.jl index 4e499ac2d1a854fca3a0c63c1e0a7884e875ebbb..bac1a8205048f1b704aff00ec3f28879aa7b7d7d 100644 --- a/src/containers/building_blocks.jl +++ b/src/containers/building_blocks.jl @@ -273,12 +273,12 @@ function get_pulse(bb::BaseBuildingBlock, time::Number) end for (fn, default_value) in ((:amplitude, 0.), (:phase, NaN), (:frequency, NaN)) - @eval function variables.$fn.f(bb::BaseBuildingBlock, time::Number) + @eval function variables.$fn(bb::BaseBuildingBlock, time::Number) pulse = get_pulse(bb, time) if isnothing(pulse) return $default_value end - return variables.$fn.f(pulse[1], pulse[2]) + return variables.$fn(pulse[1], pulse[2]) end end diff --git a/src/parts/trapezoids.jl b/src/parts/trapezoids.jl index 084accce7285f0b85656a0ceb4bbbb9926262412..8f8506edf32906772d3f9eacfe850a308dafea9f 100644 --- a/src/parts/trapezoids.jl +++ b/src/parts/trapezoids.jl @@ -208,7 +208,7 @@ end Base.keys(::SliceSelect) = (Val(:rise), Val(:flat), Val(:pulse), Val(:fall)) Base.getindex(pg::SliceSelect, ::Val{:pulse}) = (0., pg.pulse) -@defvar pulse inverse_slice_thickness(ss::SliceSelect) = 1e3 * variables.gradient_strength_norm(ss.trapezoid) .* variables.inverse_bandwidth(ss.pulse) +@defvar pulse inverse_slice_thickness(ss::SliceSelect) = 1e3 * variables.gradient_strength_norm(ss.trapezoid) * variables.inverse_bandwidth(ss.pulse) """ slice_thickness(slice_select) diff --git a/src/plot.jl b/src/plot.jl index 4132c370cd6dc8751159bd9df0c4dc85a5349c5a..be7c2c92183f5010245453bea309f4f8bd1183a0 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -244,7 +244,7 @@ This function will only work if [`Makie`](https://makie.org) is installed and im - `font` sets whether the rendered text is :regular, :bold, or :italic. - `fontsize`: set the size of each character. -$(Base.Docs.doc(generic_plot_attributes!)) +$(string(Base.Docs.@doc(generic_plot_attributes!))) """ function plot_sequence end diff --git a/src/printing.jl b/src/printing.jl index 4d67c81a3617c31fafb41e1a4a6816303de720ad..c5a79e76e58b4f5c1ab69d30478d156b496bbe2b 100644 --- a/src/printing.jl +++ b/src/printing.jl @@ -1,7 +1,7 @@ module Printing import JuMP: value, AbstractJuMPScalar import Printf: @sprintf -import ..Variables: VariableType, variables, AbstractBlock, Variable +import ..Variables: VariableType, variables, AbstractBlock import ..Containers: BuildingBlock, waveform, events, start_time function _robust_value(possible_number::AbstractJuMPScalar) diff --git a/src/variables.jl b/src/variables.jl index b853849a89a50f6521891e15d8360b6f4aa73e21..9f6f0f1b8e716830361ec966172aa4ace88ed174 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -59,32 +59,6 @@ adjust_groups(::AbstractBlock) = Symbol[] function adjust end -abstract type AnyVariable end - -""" -A sequence property that can be constrained and/or optimised. - -It acts as a function, so you can call it on a sequence or building block to get the actual values (e.g., `v(sequence)`). -It can return one of the following: -- a number -- a vector of number -- a NamedTuple with the values for individual sequence components -""" -mutable struct Variable <: AnyVariable - name :: Symbol - f :: Function - getter :: Union{Nothing, Function} -end - -struct AlternateVariable <: AnyVariable - name :: Symbol - other_var :: Symbol - from_other :: Function - to_other :: Union{Nothing, Function} - inverse :: Bool -end - - """ base_variables([T]) @@ -95,19 +69,30 @@ This only returns those [`Variable`](@ref) directly defined for this component/b If `T` is not provided, all [`Variable`](@ref) objects are returned. """ function base_variables() - all_members = (s => getproperty(variables, s) for s in Base.unsorted_names(variables, all=true)) - return Dict{Symbol, Variable}( - s => v for (s, v) in all_members if v isa Variable + return Dict{Symbol, Function}( + s => getproperty(variables, s) for s in names(variables) if s != :variables ) end function base_variables(T::Type{<:AbstractBlock}) return Dict( - s => v for (s, v) in base_variables() if hasmethod(v.f, (T, )) + s => v for (s, v) in base_variables() if which(v, (T, )) !== default_generic_method[v] ) end +""" + variable_defined_for(variable, args...; kwargs...) +""" +function variable_defined_for(name::Symbol, args...; kwargs...) + variable_defined_for(getproperty(variables, name), args...; kwargs...) +end + +function variable_defined_for(func::Function, args...; kwargs...) + return which(func, typeof.(args)) !== default_generic_method[func] +end + + # Add Variables to the individual sequence components/blocks properties function Base.propertynames(::T) where {T <: AbstractBlock} f = Base.fieldnames(T) @@ -121,9 +106,7 @@ function Base.getproperty(block::T, v::Symbol) where T <: AbstractBlock end if v in Base.unsorted_names(variables; all=true) var = getproperty(variables, v) - if hasmethod(var.f, (T, )) - return var.f(block) - end + return var(block) end error("Type $(T) has no field or variable $(v)") end @@ -140,12 +123,6 @@ function Base.setproperty!(block::T, v::Symbol, value) where T <: AbstractBlock error("Type $(T) has no field or variable $(v)") end -""" - variable_defined_for(var, Val(type)) - -Check whether variable is defined for a specific sub-type. -""" -variable_defined_for(var::Variable, ::Val{T}) where {T <: AbstractBlock} = hasmethod(var.f, (T, )) """ Main module containing all the MRIBuilder sequence variables. @@ -161,12 +138,126 @@ After sequence generation you can get the variable values by calling `variables.echo_time(seq)`. For the sequence defined above this would return 70. (or a number very close to that). """ -baremodule variables +baremodule variables end + + +""" +Contains for each variable the default, generic method. + +This is the method that checks for [`alternative_functions`](@ref) or uses one of the getters. +""" +const default_generic_method = IdDict{Function, Method}() + +""" +Mapping of variable names to alternative ways to compute that variables. +""" +const alternative_variables = Dict{Symbol, Vector{Tuple{Symbol, Function, Bool}}}() + +""" +Raised if there is no way to reach a valid function by the default, generic method ([`default_generic_method`](@ref)). +""" +struct InvalidRoute <: Exception end + +""" +Assigns getters to specific variables. +""" +const getters = Dict{Symbol, Function}() + +""" + _get_variable(name, tried_names, args...; kwargs...) + +Tries to find a route to get values for the variable `name` with the given `args` and `kwargs`. + +The route through `tried_names` has already been attempted. + +This function returns the route to get to the value and the value itself. +""" +function _get_variable(name::Symbol, tried_names::Set{Symbol}, args...; kwargs...) + if hasproperty(variables, name) + func = getproperty(variables, name) + if variable_defined_for(func, args...; kwargs...) + return Any[], func(args...; kwargs...) + end + end + + # alternative functions + if name in keys(alternative_variables) + for (new_name, converter, _) in alternative_variables[name] + if !(new_name in tried_names) + try + route, value = _get_variable(new_name, union(tried_names, [name]), args...; kwargs...) + pushfirst!(route, (:alternative, new_name)) + return route, apply_converter(converter, value) + catch e + if e isa InvalidRoute + continue + else + rethrow() + end + end + end + end + end + + if name in keys(getters) + # getters + first, others = args[1], args[2:end] + getter = getters[name] + route, value = _get_mult_variable(name, getter(first), others...; kwargs...) + pushfirst!(route, (:getter, getter)) + return route, value + end + + throw(InvalidRoute()) +end + +apply_converter(converter::Function, value) = converter(value) +apply_converter(converter::Function, value::NamedTuple) = NamedTuple(k=>converter(v) for (k, v) in pairs(value)) +apply_converter(converter::Function, value::Tuple) = Tuple(converter(v) for v in value) + +""" +Helper to call the variable for a result of a getter. Used in [`_get_variable`](@ref). +""" +_get_mult_variable(name, t, args...; kwargs...) = _get_variable(name, Set{Symbol}(), t, args...; kwargs...) +_get_mult_variable(name, t::NamedTuple, args...; kwargs...) = (Any[], NamedTuple( + key => _get_mult_variable(name, var, args...; kwargs...)[2] for (key, var) in pairs(t) +)) +_get_mult_variable(name, t::AbstractVector, args...; kwargs...) = (Any[], [ + _get_mult_variable(name, var, args...; kwargs...)[2] for var in t +]) +_get_mult_variable(name, t::Tuple, args...; kwargs...) = (Any[], Tuple( + _get_mult_variable(name, var, args...; kwargs...)[2] for var in t +)) + +""" + add_new_variable!(name) + +Adds a new variable in [`variables`](@ref). + +This is a helper function, which is called by `@defvar` for any variable that does not exist yet. +""" +function add_new_variable!(name::Symbol) + @eval variables function $(name) end + @eval variables $(Expr(:public, name)) + @eval function variables.$(name)(ab::AbstractBlock, args...; kwargs...) + try + return _get_variable($(QuoteNode(name)), Set{Symbol}(), ab, args...; kwargs...)[2] + catch e + if e isa InvalidRoute + string_name = String($(QuoteNode(name))) + error("Variable `" * string_name * "` is not defined for block of type `" * string(typeof(ab)) * ".") + end + rethrow() + end + end + + func = getproperty(variables, name) + default_generic_method[func] = methods(func)[1] end """ - @defvar([getter, ], function(s)) + @defvar(function(s)) Defines new [`variables`](@ref). @@ -187,59 +278,49 @@ Multiple variables can be defined in a single `@defvar` by including them in a c end end ``` - -Before the variable function definitions one can include a `getter`. -This `getter` defines the type of the sequence component for which the variables will be defined. -If the variable is not defined for the sequence, the variable will be extracted for those type of sequence components instead. -The following sequence component types are provided: -- `pulse`: use [`get_pulse`](@ref) -- `gradient`: use [`get_gradient`](@ref) -- `readout`: use [`get_readout`](@ref) -- `pathway`: use [`get_pathway`](@ref) -e.g. the following defines a `flip_angle` variable, which is marked as a property of an RF pulse. -```julia -@defvar pulse flip_angle(...) = ... -``` -If after this definition, `flip_angle` is not explicitly defined for any sequence, it will be extracted for the RF pulses in that sequence instead. """ -macro defvar(func_def) - return _defvar(func_def, nothing) +macro defvar(getter, func_def) + func_names, expr = _defvar(func_def) + additional = :(for name in $(func_names) + set_getter!(name, $(QuoteNode(getter))) + end) + return Expr( + :block, + expr.args..., + additional + ) end -macro defvar(getter, func_def) - return _defvar(func_def, getter) +macro defvar(func_def) + _, expr = _defvar(func_def) + return expr end -function _defvar(func_def, getter=nothing) +function _defvar(func_def) func_names = [] - - if getter isa Symbol - getter_dict = Dict( - :pulse => get_pulse, - :gradient => get_gradient, - :pathway => get_pathway, - :readout => get_readout, - ) - if !(getter in keys(getter_dict)) - error("label in `@defvar <label> <statement>` should be one of `pulse`/`gradient`/`pathway`/`readout`, not `$getter`") - end - getter = getter_dict[getter] - end - function adjust_function(ex) - if ex isa Expr && ex.head == :block - return Expr(:block, adjust_function.(ex.args)...) - end - if ex isa Expr && ex.head == :function && length(ex.args) == 1 - push!(func_names, ex.args[1]) - return :nothing + if ex isa Expr + if ex.head == :block + return Expr(:block, adjust_function.(ex.args)...) + end + if ex.head == :function && length(ex.args) == 1 + push!(func_names, ex.args[1]) + return :nothing + end + if ex.head == :macrocall && ex.args[1] == GlobalRef(Core, Symbol("@doc")) + new_expr = macroexpand(variables, ex) + func_def, add_docs = new_expr.args + fixed_func_def = adjust_function(func_def) + add_docs.args[end] = esc(add_docs.args[end]) + return Expr(:block, fixed_func_def, add_docs) + end end try fn_def = MacroTools.splitdef(ex) push!(func_names, fn_def[:name]) new_def = Dict{Symbol, Any}() - new_def[:name] = Expr(:., Expr(:., :variables, QuoteNode(fn_def[:name])), QuoteNode(:f)) + new_def[:name] = Expr(:., :variables, QuoteNode(fn_def[:name])) new_def[:args] = esc.(fn_def[:args]) new_def[:kwargs] = esc.(fn_def[:kwargs]) new_def[:body] = esc(fn_def[:body]) @@ -267,119 +348,184 @@ function _defvar(func_def, getter=nothing) for func_name in func_names push!(expressions, quote if !($(QuoteNode(func_name)) in names(variables; all=true)) - function $(func_name) end - variables.$(func_name) = Variable($(QuoteNode(func_name)), $(func_name), $getter) - end - if variables.$(func_name) isa AlternateVariable - error("$($(esc(func_name)).name) is defined through $(variables.$(func_name).other_var). Please define that variable instead.") - end - if !isnothing($getter) && variables.$(func_name).getter != $getter - if isnothing(variables.$(func_name).getter) - variables.$(func_name).getter = $getter - else - name = variables.$(func_name).name - error("$(name) is already defined as a variable for $(variables.$(func_name).getter). Cannot switch to $($getter).") - end + add_new_variable!($(QuoteNode(func_name))) end end ) end args = vcat([e.args for e in expressions]...) - return Expr( + return func_names, Expr( :block, args..., new_func_def ) end -@defvar function duration end -""" - duration(block) - -Duration of the sequence or building block in ms. -""" -variables.duration - +@defvar begin + """ + duration(block) -function def_alternate_variable!(name::Symbol, other_var::Symbol, from_other::Function, to_other::Union{Nothing, Function}, inverse::Bool) - setproperty!(variables, name, AlternateVariable(name, other_var, from_other, to_other, inverse)) + Duration of the sequence or building block in ms. + """ + function duration end end -def_alternate_variable!(:spoiler, :qval, q->1e-3 * 2π/q, l->1e-3 * 2π/l, true) -def_alternate_variable!(:qval, :qval_square, sqrt, q -> q * q, false) -def_alternate_variable!(:qval_square, :qvec, qv -> sum(q -> q * q, qv), nothing, false) """ - qval(gradient) + set_getter!(variable_name, getter) -The norm of the [`variables.qvec`](@ref). +Set the getter function for `variable_name`. + +If the value for `variable` is not defined for a sequence, the value for the result of the `getter` function is returned instead. + +Possible values for the `getter` function are: +- `:pulse`: [`get_pulse`](@ref) +- `:gradient`: [`get_gradient`](@ref) +- `:readout`: [`get_readout`](@ref) +- `:pathway`: [`get_pathway`](@ref) """ -variables.qval +set_getter!(name::Symbol, getter::Symbol) = set_getter!(name, getter_functions[getter]) + +function set_getter!(name::Symbol, getter::Function) + if !(name in keys(getters)) + getters[name] = getter + else + if getters[name] !== getter + error("") + end + end +end """ - spoiler(gradient) + add_alternative_variable!(name, other_func, conversion) -Spatial scale in mm over which the spoiler gradient will dephase by 2π. +Defines an alternative way to compute the variable with given `name`. -Automatically computed based on [`variables.qvec`](@ref). +If the variable `name` is not defined and `other_name` is, +then the value of `name` is computed by applying `conversion` to the value of `other_name`. """ -variables.spoiler +function add_alternative_variable!(name::Symbol, other_name::Symbol, conversion::Function, inverts::Bool) + if !(name in keys(alternative_variables)) + alternative_variables[name] = Tuple{Symbol, Function, Bool}[] + end + push!(alternative_variables[name], (other_name, conversion, inverts)) +end + +@defvar begin + """ + qval(gradient) + + The norm of the [`variables.qvec`](@ref). + """ + function qval end + + """ + qval_square(gradient) + + The square of the area under the curve ([`variables.qval`](@ref)). + + Constraining this can be more efficient than constraining [`qval`](@ref) as it avoids taking a square root. + """ + function qval_square end + + """ + spoiler(gradient) + + Spatial scale in mm over which the spoiler gradient will dephase by 2π. + + Automatically computed based on [`variables.qvec`](@ref). + """ + function spoiler end + + """ + qvec(gradient) + + The total integral of the area under the gradient curve as a length-3 vector. + + The norm of this vector is available as [`qval`](@ref). + """ + function qvec end +end + +add_alternative_variable!(:spoiler, :qval, q->1e-3 * 2π/q, true) +add_alternative_variable!(:qval, :spoiler, l->1e-3 * 2π/l, true) +add_alternative_variable!(:qval, :qval_square, sqrt, false) +add_alternative_variable!(:qval_square, :qval, q->q^2, false) +add_alternative_variable!(:qval_square, :qvec, qv -> sum(q -> q * q, qv), false) + for vec_variable in [:gradient_strength, :slew_rate] vec_square = Symbol(string(vec_variable) * "_square") vec_norm = Symbol(string(vec_variable) * "_norm") - def_alternate_variable!(vec_norm, vec_square, sqrt, v -> v * v, false) - def_alternate_variable!(vec_square, vec_variable, v -> v[1] * v[1] + v[2] * v[2] + v[3] * v[3], nothing, false) + add_alternative_variable!(vec_square, vec_variable, v -> v[1] * v[1] + v[2] * v[2] + v[3] * v[3], false) + add_alternative_variable!(vec_norm, vec_square, sqrt, false) + add_alternative_variable!(vec_square, vec_norm, v->v^2, false) end -""" - gradient_strength_norm(gradient) +@defvar begin + """ + gradient_strength_norm(gradient) -The norm of the [`variables.gradient_strength`](@ref). -""" -variables.gradient_strength_norm + The norm of the [`variables.gradient_strength`](@ref). + """ + function gradient_strength_norm end -""" - slew_rate_norm(gradient) + """ + slew_rate_norm(gradient) -The norm of the [`variables.slew_rate`](@ref). -""" -variables.slew_rate_norm + The norm of the [`variables.slew_rate`](@ref). + """ + function slew_rate_norm end +end -for name in [:slice_thickness, :bandwidth, :fov, :voxel_size] +for name in [:fov, :voxel_size] inv_name = Symbol("inverse_" * string(name)) - def_alternate_variable!(name, inv_name, inv, inv, true) + mult_inv(x) = inv.(x) + add_alternative_variable!(name, inv_name, mult_inv, true) + add_alternative_variable!(inv_name, name, mult_inv, true) + add_new_variable!(name) + add_new_variable!(inv_name) end -for (name, alt_name) in [ - (:TE, :echo_time), - (:TR, :repetition_time), - (:Δ, :diffusion_time), -] - def_alternate_variable!(name, alt_name, identity, identity, false) +for name in [:slice_thickness, :bandwidth] + inv_name = Symbol("inverse_" * string(name)) + add_alternative_variable!(name, inv_name, inv, true) + add_alternative_variable!(inv_name, name, inv, true) + add_new_variable!(name) + add_new_variable!(inv_name) end -""" - TE(sequence) -Alternative name to compute the [`variables.echo_time`](@ref) of a sequence in ms. -""" -variables.TE +@defvar begin + """ + echo_time(sequence) -""" - TR(sequence) + Computes the echo time(s) of a sequence in ms. + """ + function echo_time end -Alternative name to compute the [`variables.repetition_time`](@ref) of a sequence in ms. -""" -variables.TR + """ + repetition_time(sequence) -""" - Δ(sequence) + Computes the repetition_times of a sequence in ms. + """ + function repetition_time end -Alternative name to compute the [`variables.diffusion_time`](@ref) of a sequence in ms. -""" -variables.Δ + """ + diffusion_time(sequence) + + Computes the diffusion time of a sequence in ms. + """ + function diffusion_time end +end + +@eval variables begin + const TE = echo_time + const TR = repetition_time + const Δ = diffusion_time + $(Expr(:public, :TE, :TR, :Δ)) +end """ @@ -461,6 +607,17 @@ Any `pathway` variables not explicitly defined for this sequence will be passed """ function get_pathway end +""" +Mapping of symbols to actual getter functions. + +Used in [`set_getter!`](@ref). +""" +getter_functions = Dict( + :pulse => get_pulse, + :gradient => get_gradient, + :pathway => get_pathway, + :readout => get_readout, +) """ gradient_orientation(building_block) @@ -470,48 +627,6 @@ Returns the gradient orientation. function gradient_orientation end -function (var::Variable)(block::AbstractBlock, args...; kwargs...) - if !applicable(var.f, block, args...) && !isnothing(var.getter) - apply_to = var.getter(block) - if apply_to isa AbstractBlock - return var(apply_to, args...; kwargs...) - elseif apply_to isa NamedTuple - return NamedTuple(k => var(v, args...; kwargs...) for (k, v) in pairs(apply_to)) - elseif apply_to isa AbstractVector{<:AbstractBlock} || apply_to isa Tuple - return var.(apply_to, args...; kwargs...) - else - error("$(var.getter) returned an unexpected type: $(typeof(apply_to)).") - end - end - return var.f(block, args...; kwargs...) -end - -# Special case for BuildingBlock events -function (var::Variable)(event::Tuple{<:VariableType, <:AbstractBlock}, args...; kwargs...) - if applicable(var.f, event, args...; kwargs...) - return var.f(event, args...; kwargs...) - end - # falling back to just processing the `AbstractBlock` - return var(event[2], args...; kwargs...) -end - -function (var::AlternateVariable)(args...; kwargs...) - other_var = getproperty(variables, var.other_var) - apply_from_other(res::VariableType) = var.from_other(res) - function apply_from_other(res::AbstractVector{<:VariableType}) - try - return var.from_other(res) - catch e - if e isa MethodError - return var.from_other.(res) - end - end - end - apply_from_other(res::NamedTuple) = NamedTuple(k => apply_from_other(v) for (k, v) in pairs(res)) - return apply_from_other(other_var(args...; kwargs...)) -end - - """ add_cost_function!(function, level=2) @@ -545,24 +660,68 @@ function set_simple_constraints!(block::AbstractBlock, kwargs) real_kwargs = Dict(key => value for (key, value) in kwargs if !isnothing(value)) for (key, value) in real_kwargs - var = getproperty(variables, key) - if var isa AlternateVariable - if var.other_var in keys(real_kwargs) - error("Set constraints on both $key and $(var.other_var), however they are equivalent.") + apply_simple_constraint!(block, key, value) + end + nothing +end + +apply_simple_constraint!(block, variable::Symbol, value::Nothing) = nothing +apply_simple_constraint!(block::Union{NamedTuple, Tuple}, variable::Symbol, value::Nothing) = nothing + +function apply_simple_constraint!(block::NamedTuple, variable::Symbol, value::NamedTuple) + for (k, v) in pairs(value) + apply_simple_constraint!(getproperty(block, k), variable, v) + end +end +function apply_simple_constraint!(block::Tuple, variable::Symbol, value::Tuple) + @assert length(block) == length(value) + for (b, v) in zip(block, value) + apply_simple_constraint!(b, variable, v) + end +end +function apply_simple_constraint!(block::Union{NamedTuple, Tuple}, variable::Symbol, value) + for b in block + apply_simple_constraint!(b, variable, value) + end +end + +function apply_simple_constraint!(block, variable::Symbol, value, previous_values=Set{Symbol}()::Set{Symbol}) + route, to_set = try + _get_variable(variable, previous_values, block) + catch e + if e isa InvalidRoute + error("Variable `$(variable)` cannot be set to constrain a building block/sequence of type $(typeof(block)).") + end + rethrow() + end + if iszero(length(route)) + return apply_simple_constraint!(to_set, value) + else + (step_type, step) = route[1] + if step_type == :alternative + if !(step in keys(alternative_variables)) + return apply_simple_constraint!(to_set, value) + end + for (my_name, converter, invert) in alternative_variables[step] + if my_name == variable + return apply_simple_constraint!(block, step, _adjust_value(value, converter, invert), union(previous_values, [variable])) + end end - invert_value(value::VariableType) = var.to_other(value) - invert_value(value::Symbol) = invert_value(Val(value)) - invert_value(::Val{:min}) = var.inverse ? Val(:max) : Val(:min) - invert_value(::Val{:max}) = var.inverse ? Val(:min) : Val(:max) - invert_value(value::AbstractVector) = invert_value.(value) - apply_simple_constraint!(getproperty(variables, var.other_var)(block), invert_value(value)) + return apply_simple_constraint!(to_set, value) + elseif step_type == :getter + apply_simple_constraint!(step(block), variable, value) else - apply_simple_constraint!(var(block), value) + error() end end - nothing end +_adjust_value(value::Symbol, converter::Function, invert::Bool) = _adjust_value(Val(value), converter, invert) +_adjust_value(::Val{:min}, ::Function, invert::Bool) = invert ? Val(:max) : Val(:min) +_adjust_value(::Val{:max}, ::Function, invert::Bool) = invert ? Val(:min) : Val(:max) +_adjust_value(value::VariableType, converter::Function, ::Bool) = converter(value) +_adjust_value(value::AbstractArray, converter::Function, ::Bool) = converter(value) + """ apply_simple_constraint!(variable, value)