Skip to content
Snippets Groups Projects
Unverified Commit 092bff04 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

Set getter for @defvar

Also, added some more error checking to @defvar
parent e980ac80
No related branches found
No related tags found
1 merge request!2Define variables through new @defvar macro
This commit is part of merge request !2. Comments created here will be created in the context of that merge request.
......@@ -22,8 +22,8 @@ export build_sequence, global_model, global_scanner, fixed
import .Scanners: Scanner, B0, Siemens_Connectom, Siemens_Prisma, Siemens_Terra, Default_Scanner
export Scanner, B0, Siemens_Connectom, Siemens_Prisma, Siemens_Terra, Default_Scanner
import .Variables: variables, duration, effective_time, flip_angle, amplitude, phase, frequency, bandwidth, N_left, N_right, qval, δ, rise_time, flat_time, slew_rate, gradient_strength, qvec, qval_square, slice_thickness, inverse_slice_thickness, fov, inverse_fov, voxel_size, inverse_voxel_size, resolution, nsamples, oversample, dwell_time, ramp_overlap, spoiler_scale, repetition_time, TR, Δ, get_gradient, get_pulse, get_readout, TE, echo_time, diffusion_time, make_generic, slew_rate3, gradient_strength3, qval3
export variables, duration, effective_time, flip_angle, amplitude, phase, frequency, bandwidth, N_left, N_right, qval, δ, rise_time, flat_time, slew_rate, gradient_strength, qvec, qval_square, slice_thickness, inversne_slice_thickness, fov, inverse_fov, voxel_size, inverse_voxel_size, resolution, nsamples, oversample, dwell_time, ramp_overlap, spoiler_scale, repetition_time, TR, Δ, get_gradient, get_pulse, get_readout, TE, echo_time, diffusion_time, make_generic, slew_rate3, gradient_strength3, qval3
import .Variables: variables, effective_time, make_generic, @defvar, duration
export variables, effective_time, make_generic, @defvar, duration
#import .Components: InstantPulse, ConstantPulse, SincPulse, GenericPulse, InstantGradient, SingleReadout, ADC, CompositePulse, edge_times
#export InstantPulse, ConstantPulse, SincPulse, GenericPulse, InstantGradient, SingleReadout, ADC, CompositePulse, edge_times
......
......@@ -65,11 +65,21 @@ It can return one of the following:
- a vector of number
- a NamedTuple with the values for individual sequence components
"""
struct Variable <: AnyVariable
mutable struct Variable <: AnyVariable
name :: Symbol
f :: Function
getter :: Union{Nothing, Function}
end
struct AlternateVariable <: AnyVariable
name :: Symbol
other_var :: Symbol
to_other :: Function
from_other :: Function
inverse :: Bool
end
"""
variable_defined_for(var, Val(type))
......@@ -77,12 +87,50 @@ Check whether variable is defined for a specific sub-type.
"""
variable_defined_for(var::Variable, ::Val{T}) where {T <: AbstractBlock} = hasmethod(var.f, (T, ))
all_variables = Dict{Symbol, AnyVariable}()
struct _Variables
variables :: Dict{Symbol, AnyVariable}
end
variables = _Variables(Dict{Symbol, AnyVariable}())
Base.getindex(v::_Variables, i::Symbol) = getfield(v, :variables)[i]
Base.keys(v::_Variables) = keys(getfield(v, :variables))
macro defvar(func_def)
Base.propertynames(v::_Variables) = Tuple(keys(getfield(v, :variables)))
Base.getproperty(v::_Variables, s::Symbol) = v[s]
macro defvar(func_def)
return _defvar(func_def, nothing)
end
macro defvar(getter, func_def)
return _defvar(func_def, getter)
end
function _defvar(func_def, getter=nothing)
func_names = []
if getter isa Symbol
getter_dict = Dict(
:pulse => get_pulse,
:gradient => get_gradient,
:pathway => get_pathway,
:readout => get_readout,
)
if !(getter in keys(getter_dict))
error("label in `@defvar <label> <statement>` should be one of `pulse`/`gradient`/`pathway`/`readout`, not `$getter`")
end
getter = getter_dict[getter]
end
function adjust_function(ex)
if ex isa Expr && ex.head == :function && length(ex.args) == 1
push!(func_names, ex.args[1])
return :nothing
end
try
fn_def = MacroTools.splitdef(ex)
push!(func_names, fn_def[:name])
......@@ -109,18 +157,30 @@ macro defvar(func_def)
expressions = Expr[]
for func_name in func_names
push!(expressions, quote
$(esc(func_name)) = if $(QuoteNode(func_name)) in keys(all_variables)
all_variables[$(QuoteNode(func_name))]
$(esc(func_name)) = if $(QuoteNode(func_name)) in keys(variables)
variables[$(QuoteNode(func_name))]
else
function $(func_name) end
all_variables[$(QuoteNode(func_name))] = Variable($(QuoteNode(func_name)), $(func_name))
getfield(variables, :variables)[$(QuoteNode(func_name))] = Variable($(QuoteNode(func_name)), $(func_name), $getter)
end
if $(esc(func_name)) isa AlternateVariable
error("$($(esc(func_name)).name) is defined through $($(esc(func_name)).other_var). Please define that variable instead.")
end
if !isnothing($getter) && $(esc(func_name)).getter != $getter
if isnothing($(esc(func_name)).getter)
$(esc(func_name)).getter = $getter
else
name = $(esc(func_name)).name
error("$(name) is already defined as a variable for $($(esc(func_name)).getter). Cannot switch to $($getter).")
end
end
end
)
end
args = vcat([e.args for e in expressions]...)
return Expr(
:block,
expressions...,
args...,
new_func_def
)
end
......@@ -132,20 +192,16 @@ Duration of the sequence or building block in ms.
""" duration
struct AlternateVariable <: AnyVariable
name :: Symbol
other_var :: Symbol
to_other :: Function
from_other :: Function
inverse :: Bool
function def_alternate_variable!(name::Symbol, other_var::Symbol, to_other::Function, from_other::Function, inverse::Bool)
getfield(variables, :variables)[name] = AlternateVariable(name, other_var, to_other, from_other, inverse)
end
all_variables[:qval] = AlternateVariable(:qval, :qval_square, q->q^2, sqrt, false)
all_variables[:spoiler_scale] = AlternateVariable(:spoiler_scale, :spoiler_scale, q->1e-3 * 2π/q, l->1e-3 * 2π/l, true)
def_alternate_variable!(:qval, :qval_square, q->q^2, sqrt, false)
def_alternate_variable!(:spoiler_scale, :spoiler_scale, q->1e-3 * 2π/q, l->1e-3 * 2π/l, true)
for name in [:slice_thickness, :bandwidth, :fov, :voxel_size]
inv_name = Symbol("inverse_" * string(name))
all_variables[name] = AlternateVariable(name, inv_name, inv, inv, true)
def_alternate_variable!(name, inv_name, inv, inv, true)
end
for (name, alt_name) in [
......@@ -153,7 +209,7 @@ for (name, alt_name) in [
(:TR, :repetition_time),
(:Δ, :diffusion_time),
]
all_variables[name] = AlternateVariable(name, alt_name, identity, identity, false)
def_alternate_variable!(name, alt_name, identity, identity, false)
end
......@@ -195,7 +251,7 @@ function get_free_variable(::Val{:max}; kwargs...)
end
"""
get_pulse(building_block)]
get_pulse(building_block)
Get the pulse played out during the building block.
......@@ -204,7 +260,7 @@ Any `pulse` variables not explicitly defined for this building block will be pas
function get_pulse end
"""
get_gradient(building_block)]
get_gradient(building_block)
Get the gradient played out during the building block.
......@@ -213,7 +269,7 @@ Any `gradient` variables not explicitly defined for this building block will be
function get_gradient end
"""
get_readout(building_block)]
get_readout(building_block)
Get the readout played out during the building block.
......@@ -221,6 +277,15 @@ Any `readout` variables not explicitly defined for this building block will be p
"""
function get_readout end
"""
get_pathway(sequence)
Get the default spin pathway for the sequence.
Any `pathway` variables not explicitly defined for this building block will be passed on to the pathway.
"""
function get_pathway end
"""
bmat_gradient(gradient::GradientBlock, qstart=(0, 0, 0))
......@@ -243,76 +308,26 @@ function gradient_orientation end
function effective_time end
"""
VariableNotAvailable(building_block, variable, alt_variable)
Exception raised when a variable function does not support a specific `AbstractBlock`.
"""
mutable struct VariableNotAvailable <: Exception
bb :: Type{<:AbstractBlock}
variable :: Function
alt_variable :: Union{Nothing, Function}
end
VariableNotAvailable(bb::Type{<:AbstractBlock}, variable::Function) = VariableNotAvailable(bb, variable, nothing)
function Base.showerror(io::IO, e::VariableNotAvailable)
if isnothing(e.alt_variable)
print(io, e.variable, " is not available for block of type ", e.bb, ".")
else
print(io, e.variable, " is not available for block of type ", e.bb, ". Please use ", e.alt_variable, " instead to set any contsraints or objective functions.")
function (var::Type{Variable})(block::AbstractBlock, args...; kwargs...)
if !applicable(var.f, block, args...) && !isnothing(var.getter)
apply_to = var.getter(block)
if apply_to isa AbstractBlock
return var(apply_to, args...; kwargs...)
elseif apply_to isa NamedTuple
return NamedTuple(k => var(v, args...; kwargs...) for (k, v) in pairs(apply_to))
elseif apply_to isa AbstractVector{<:AbstractBlock} || apply_to isa Tuple
return var.(apply_to, args...; kwargs...)
end
end
return var.f(block, args...; kwargs...)
end
for (target_name, all_vars) in all_variables_symbols
for (variable_func, _) in all_vars
if variable_func in [:qval3, :TR, :TE, :Δ]
continue
end
get_func = Symbol("get_" * string(target_name))
use_get_func = target_name in (:pulse, :readout, :gradient)
@eval function Variables.$variable_func(bb::AbstractBlock)
try
if Variables.$variable_func in keys(alternative_variables)
alt_var, forward, backward, _ = alternative_variables[Variables.$variable_func]
try
value = alt_var(bb)
if value isa Number
return backward(value)
elseif value isa AbstractArray{<:Number}
return backward.(value)
end
catch e
if e isa VariableNotAvailable
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func))
end
rethrow()
end
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func, alt_var))
end
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func))
catch e
if $use_get_func && e isa VariableNotAvailable && hasmethod($get_func, tuple(typeof(bb)))
apply_to = try
$(get_func)(bb)
catch
throw(VariableNotAvailable(typeof(bb), Variables.$variable_func))
end
if apply_to isa AbstractBlock
return Variables.$variable_func(apply_to)
elseif apply_to isa NamedTuple
return NamedTuple(k => Variables.$variable_func(v) for (k, v) in pairs(apply_to))
elseif apply_to isa AbstractVector{<:AbstractBlock} || apply_to isa Tuple
return Variables.$variable_func.(apply_to)
end
error("$($(use_get_func)) returned an unexpected type for $(bb).")
end
rethrow()
end
end
end
function (var::Type{AlternateVariable})(args...; kwargs...)
other_var = variables[var.other_var]
apply_from_other(res::Number) = var.from_other(res)
apply_from_other(res::AbstractArray{<:Number}) = var.from_other.(res)
apply_from_other(res::NamedTuple) = NamedTuple(k => apply_from_other(v) for (k, v) in pairs(res))
return apply_from_other(other_var(args...; kwargs...))
end
......@@ -346,25 +361,23 @@ If set to a numeric value, a constraint will be added to fix the function value
If set to `:min` or `:max`, minimising or maximising this function will be added to the cost function.
"""
function set_simple_constraints!(block::AbstractBlock, kwargs)
for (key, value) in kwargs
if variables[key] in keys(alternative_variables)
alt_var, forward, backward, to_invert = alternative_variables[variables[key]]
invert_value(value::VariableType) = forward(value)
real_kwargs = NamedTuple(key => value for (key, value) in kwargs if !isnothing(value))
for (key, value) in real_kwargs
var = variables[key]
if var isa AlternateVariable
if var.other_var in real_kwargs
error("Set constraints on both $key and $(var.other_var), however they are equivalent.")
end
invert_value(value::VariableType) = var.to_other(value)
invert_value(value::Symbol) = invert_value(Val(value))
invert_value(::Val{:min}) = to_invert ? Val(:max) : Val(:min)
invert_value(::Val{:max}) = to_invert ? Val(:min) : Val(:max)
invert_value(::Val{:min}) = var.inverse ? Val(:max) : Val(:min)
invert_value(::Val{:max}) = var.inverse ? Val(:min) : Val(:max)
invert_value(value::AbstractVector) = invert_value.(value)
invert_value(value) = value
try
apply_simple_constraint!(alt_var(block), invert_value(value))
continue
catch e
if !(e isa VariableNotAvailable)
rethrow()
end
end
apply_simple_constraint!(variables[var.other_var](block), invert_value(value))
else
apply_simple_constraint!(var(block), value)
end
apply_simple_constraint!(variables[key](block), value)
end
nothing
end
......@@ -381,7 +394,6 @@ Add a single constraint or objective to the `variable`.
- `number`: fix variable to this value
- `equation`: fix variable to the result of this equation
"""
apply_simple_constraint!(variable::VariableType, ::Nothing) = nothing
apply_simple_constraint!(variable::AbstractVector, value::Symbol) = apply_simple_constraint!(sum(variable), Val(value))
apply_simple_constraint!(variable::VariableType, value::Symbol) = apply_simple_constraint!(variable, Val(value))
apply_simple_constraint!(variable::VariableType, ::Val{:min}) = @objective global_model() Min objective_function(global_model()) + variable
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment