From 25a97cb0f837e737b033971e45a76948a5757395 Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <MichielCottaar@protonmail.com> Date: Fri, 24 May 2024 15:38:29 +0100 Subject: [PATCH] Convert containers to use @defvar --- src/MRIBuilder.jl | 6 +-- src/containers/abstract.jl | 22 +++++---- src/containers/alternatives.jl | 4 +- src/containers/base_sequences.jl | 19 +++++--- src/containers/building_blocks.jl | 78 ++++++++++++++++--------------- 5 files changed, 71 insertions(+), 58 deletions(-) diff --git a/src/MRIBuilder.jl b/src/MRIBuilder.jl index a6b40dc..ed57857 100644 --- a/src/MRIBuilder.jl +++ b/src/MRIBuilder.jl @@ -7,7 +7,7 @@ include("scanners.jl") include("build_sequences.jl") include("variables.jl") include("components/components.jl") -#include("containers/containers.jl") +include("containers/containers.jl") #include("pathways.jl") #include("parts/parts.jl") #include("post_hoc.jl") @@ -28,8 +28,8 @@ export variables, effective_time, make_generic, @defvar, duration import .Components: InstantPulse, ConstantPulse, SincPulse, GenericPulse, InstantGradient, SingleReadout, ADC, CompositePulse, edge_times export InstantPulse, ConstantPulse, SincPulse, GenericPulse, InstantGradient, SingleReadout, ADC, CompositePulse, edge_times -#import .Containers: ContainerBlock, start_time, end_time, waveform, waveform_sequence, events, BaseBuildingBlock, BuildingBlock, Wait, BaseSequence, nrepeat, Sequence, AlternativeBlocks, match_blocks!, get_index_single_TR, readout_times, iter_blocks, iter_instant_gradients, iter_instant_pulses -#export ContainerBlock, start_time, end_time, waveform, waveform_sequence, events, BaseBuildingBlock, BuildingBlock, Wait, BaseSequence, nrepeat, Sequence, AlternativeBlocks, match_blocks!, get_index_single_TR, readout_times, iter_blocks, iter_instant_gradients, iter_instant_pulses +import .Containers: ContainerBlock, start_time, end_time, waveform, waveform_sequence, events, BaseBuildingBlock, BuildingBlock, Wait, BaseSequence, nrepeat, Sequence, AlternativeBlocks, match_blocks!, get_index_single_TR, readout_times, iter_blocks, iter_instant_gradients, iter_instant_pulses +export ContainerBlock, start_time, end_time, waveform, waveform_sequence, events, BaseBuildingBlock, BuildingBlock, Wait, BaseSequence, nrepeat, Sequence, AlternativeBlocks, match_blocks!, get_index_single_TR, readout_times, iter_blocks, iter_instant_gradients, iter_instant_pulses #import .Pathways: Pathway, duration_transverse, duration_dephase, bval, bmat, get_pathway #export Pathway, duration_transverse, duration_dephase, bval, bmat, get_pathway diff --git a/src/containers/abstract.jl b/src/containers/abstract.jl index 445d24e..d4ed89e 100644 --- a/src/containers/abstract.jl +++ b/src/containers/abstract.jl @@ -1,6 +1,5 @@ module Abstract -import ...Variables: AbstractBlock, duration, effective_time, gradient_strength, amplitude, phase, VariableType, get_pulse, get_gradient -import ...Components.Readouts: readout_times +import ...Variables: AbstractBlock, variables, VariableType, get_pulse, get_gradient, @defvar import ...Components: BaseComponent, InstantPulse, InstantGradient, ReadoutComponent, NoGradient, RFPulseComponent, GradientWaveform """ @@ -34,6 +33,10 @@ end_time(container::ContainerBlock, index, indices...) = start_time(container, i end_time(block::AbstractBlock) = duration(block) end_time(block::Tuple{<:VariableType, <:AbstractBlock}) = duration(block[2]) +@defvar begin + effective_time(container::ContainerBlock, index, indices...) = start_time(container, index) + effective_time(container[index], indices...) + effective_time(block::Tuple{<:VariableType, <:AbstractBlock}) = block[1] + effective_time(block[2]) +end """ effective_time(container, indices...) @@ -43,8 +46,7 @@ This will crash if the component does not have an [`effective_time`](@ref) (e.g. Also see [`duration`](@ref), [`start_time`](@ref), and [`end_time`](@ref) """ -effective_time(container::ContainerBlock, index, indices...) = start_time(container, index) + effective_time(container[index], indices...) -effective_time(block::Tuple{<:VariableType, <:AbstractBlock}) = block[1] + effective_time(block[2]) +effective_time """ @@ -52,7 +54,7 @@ effective_time(block::Tuple{<:VariableType, <:AbstractBlock}) = block[1] + effec Returns the gradient strength at a particular time within the sequence. """ -function gradient_strength end +gradient_strength = variables.gradient_strength """ @@ -60,7 +62,7 @@ function gradient_strength end Returns the RF amplitude at a particular time within the sequence in kHz. """ -function amplitude end +amplitude = variables.amplitude """ @@ -70,7 +72,7 @@ Returns the RF phase at a particular time within the sequence in degrees. NaN is returned if there is no pulse activate at that `time`. """ -function phase end +phase = variables.phase """ frequency(sequence, time) @@ -79,7 +81,7 @@ Returns the RF frequency at a particular time within the sequence in kHz. NaN is returned if there is no pulse activate at that `time`. """ -function frequency end +frequency = variables.frequency """ iter(sequence, get_type) @@ -139,12 +141,14 @@ This function will return a tuple with 2 elements: """ function get_gradient end +@defvar readout_times(container::ContainerBlock) = [time for (time, _) in iter(container, Val(:readout))] """ readout_times(sequence) Returns all the times that the sequence will readout. """ -readout_times(container::ContainerBlock) = [time for (time, _) in iter(container, Val(:readout))] +readout_times + iter(component::Tuple{<:Number, <:ReadoutComponent}, ::Val{:readout}) = [(time, nothing) for time in readout_times(component[2])] end \ No newline at end of file diff --git a/src/containers/alternatives.jl b/src/containers/alternatives.jl index 2b00577..0c75bbe 100644 --- a/src/containers/alternatives.jl +++ b/src/containers/alternatives.jl @@ -2,7 +2,7 @@ module Alternatives import JuMP: @constraint import ..Abstract: ContainerBlock import ...BuildSequences: global_model, fixed -import ...Variables: duration, make_generic +import ...Variables: @defvar, make_generic """ AlternativeBlocks(name, blocks) @@ -23,7 +23,7 @@ AlternativeBlocks(name::Symbol, options_vector::AbstractVector) = AlternativeBlo Base.getindex(alt::AlternativeBlocks, index) = alt.options[index] Base.length(alt::AlternativeBlocks) = length(alt.options) -duration(alt::AlternativeBlocks) = maximum(duration.(values(alt.options))) +@defvar duration(alt::AlternativeBlocks) = maximum(duration.(values(alt.options))) """ match_blocks!(alternatives, function) diff --git a/src/containers/base_sequences.jl b/src/containers/base_sequences.jl index 612eac8..5f6d45a 100644 --- a/src/containers/base_sequences.jl +++ b/src/containers/base_sequences.jl @@ -4,7 +4,7 @@ Defines [`BaseSequence`](@ref) and [`Sequence`](@ref) module BaseSequences import StaticArrays: SVector import JuMP: @constraint -import ...Variables: get_free_variable, repetition_time, VariableType, duration, variables, VariableNotAvailable, Variables, set_simple_constraints!, TR, make_generic, gradient_strength, amplitude, phase, gradient_strength3, get_gradient, get_pulse, frequency, gradient_orientation, get_gradient +import ...Variables: get_free_variable, VariableType, variables, set_simple_constraints!, make_generic, get_gradient, get_pulse, get_gradient, @defvar import ...BuildSequences: global_model, global_scanner import ...Components: EventComponent, NoGradient, edge_times import ...Scanners: Scanner, B0 @@ -102,9 +102,10 @@ nrepeat(bs::BaseSequence) = 1 repetition_time(bs::BaseSequence) = duration(bs) - -duration(bs::BaseSequence{0}) = 0. -duration(bs::BaseSequence) = sum(duration.(bs); init=0.) +@defvar begin + duration(bs::BaseSequence{0}) = 0. + duration(bs::BaseSequence) = sum(duration.(bs); init=0.) +end function edge_times(seq::BaseSequence; tol=1e-6) res = Float64[] @@ -120,13 +121,19 @@ function edge_times(seq::BaseSequence; tol=1e-6) return sort(unique_res) end -for fn in (:gradient_strength, :amplitude, :phase, :frequency, :gradient_strength3, :get_gradient, :get_pulse) +for fn in (:gradient_strength, :amplitude, :phase, :frequency, :gradient_strength3) @eval function $fn(sequence::BaseSequence, time::AbstractFloat) (block_time, block) = sequence(time) - return $fn(block, block_time) + return variables.$fn.f(block, block_time) end end +for fn in (:get_gradient, :get_pulse) + @eval function $fn(sequence::BaseSequence, time::AbstractFloat) + (block_time, block) = sequence(time) + return $fn(block, block_time) + end +end """ Sequence(blocks; name=:Sequence, variables...) Sequence(blocks...; name=:Sequence, variables...) diff --git a/src/containers/building_blocks.jl b/src/containers/building_blocks.jl index a983837..8849341 100644 --- a/src/containers/building_blocks.jl +++ b/src/containers/building_blocks.jl @@ -8,8 +8,7 @@ import StaticArrays: SVector import ..Abstract: ContainerBlock, start_time, readout_times, end_time, iter import ...BuildSequences: global_model import ...Components: BaseComponent, GradientWaveform, EventComponent, NoGradient, ChangingGradient, ConstantGradient, split_gradient, RFPulseComponent, ReadoutComponent, InstantGradient, edge_times -import ...Variables: qval, bmat_gradient, effective_time, get_free_variable, qval3, slew_rate, gradient_strength, amplitude, phase, frequency -import ...Variables: VariableType, duration, make_generic, get_pulse, get_readout, scanner_constraints!, get_gradient, gradient_orientation +import ...Variables: VariableType, make_generic, get_pulse, get_readout, scanner_constraints!, get_gradient, gradient_orientation, variables, @defvar """ Basic BuildingBlock, which can consist of a gradient waveforms with any number of RF pulses/readouts overlaid @@ -143,7 +142,7 @@ function start_time(building_block::BaseBuildingBlock, index) error("Building block with index '$index' not found") end -duration(bb::BaseBuildingBlock) = sum([duration(wv) for (_, wv) in waveform_sequence(bb)]) +@defvar duration(bb::BaseBuildingBlock) = sum([duration(wv) for (_, wv) in waveform_sequence(bb)]) # Pathway support """ @@ -191,6 +190,39 @@ function waveform_sequence(bb::BaseBuildingBlock, first, last) return parts end +@defvar begin + function qval(bb::BaseBuildingBlock, index1, index2) + if (!isnothing(index1)) && (index1 == index2) + return 0. + end + res = sum([qval(wv) for (_, wv) in waveform_sequence(bb, index1, index2)]) + + t1 = isnothing(index1) ? 0. : start_time(bb, index1) + t2 = isnothing(index2) ? duration(bb) : start_time(bb, index2) + for (key, event) in events(bb) + if event isa InstantGradient && (t1 <= start_time(bb, key) <= t2) + res = res .+ qval(event) + end + end + return res + end + qval(bb::BaseBuildingBlock) = qval(bb, nothing, nothing) + + function bmat_gradient(bb::BaseBuildingBlock, qstart, index1, index2) + if (!isnothing(index1)) && (index1 == index2) + return zeros(3, 3) + end + result = Matrix{VariableType}(zeros(3, 3)) + qcurrent = Vector{VariableType}(qstart) + + for (_, part) in waveform_sequence(bb, index1, index2) + result = result .+ bmat_gradient(part, qcurrent) + qcurrent = qcurrent .+ qval3(part, qcurrent) + end + return result + end + bmat_gradient(bb::BaseBuildingBlock, qstart) = bmat_gradient(bb, qstart, nothing, nothing) +end """ qval(overlapping[, first_event, last_event]) @@ -200,37 +232,7 @@ Computes the area under the curve for the gradient waveform in [`BaseBuildingBlo If `first_event` is set to something else than `nothing`, only the gradient waveform after this RF pulse/Readout will be considered. Similarly, if `last_event` is set to something else than `nothing`, only the gradient waveform up to this RF pulse/Readout will be considered. """ -function qval(bb::BaseBuildingBlock, index1, index2) - if (!isnothing(index1)) && (index1 == index2) - return 0. - end - res = sum([qval(wv) for (_, wv) in waveform_sequence(bb, index1, index2)]) - - t1 = isnothing(index1) ? 0. : start_time(bb, index1) - t2 = isnothing(index2) ? duration(bb) : start_time(bb, index2) - for (key, event) in events(bb) - if event isa InstantGradient && (t1 <= start_time(bb, key) <= t2) - res = res .+ qval(event) - end - end - return res -end -qval(bb::BaseBuildingBlock) = qval(bb, nothing, nothing) - -function bmat_gradient(bb::BaseBuildingBlock, qstart, index1, index2) - if (!isnothing(index1)) && (index1 == index2) - return zeros(3, 3) - end - result = Matrix{VariableType}(zeros(3, 3)) - qcurrent = Vector{VariableType}(qstart) - - for (_, part) in waveform_sequence(bb, index1, index2) - result = result .+ bmat_gradient(part, qcurrent) - qcurrent = qcurrent .+ qval3(part, qcurrent) - end - return result -end -bmat_gradient(bb::BaseBuildingBlock, qstart) = bmat_gradient(bb, qstart, nothing, nothing) +qval function edge_times(bb::BaseBuildingBlock) res = Float64[] @@ -254,7 +256,7 @@ function get_pulse(bb::BaseBuildingBlock, time::Number) end for (fn, default_value) in ((:amplitude, 0.), (:phase, NaN), (:frequency, NaN)) - @eval function $fn(bb::BaseBuildingBlock, time::Number) + @eval function variables.$fn.f(bb::BaseBuildingBlock, time::Number) pulse = get_pulse(bb, time) if isnothing(pulse) return $default_value @@ -273,7 +275,7 @@ function get_gradient(bb::BaseBuildingBlock, time::Number) error("$bb with duration $(duration(bb)) does not define a gradient at time $time.") end -function gradient_strength(bb::BaseBuildingBlock, time::Number) +@defvar function gradient_strength(bb::BaseBuildingBlock, time::Number) (grad, time) = get_gradient(bb, time) return gradient_strength(grad, time) end @@ -353,7 +355,7 @@ function get_readout(bb::BuildingBlock) error("BuildingBlock contains more than one readout. Not sure which one to return.") end -function effective_time(bb::BuildingBlock) +@defvar function effective_time(bb::BuildingBlock) index = [i for (i, r) in events(bb) if r isa Union{RFPulseComponent, ReadoutComponent}] if length(index) == 0 error("BuildingBlock does not contain any RF pulse or readout events, so `effective_time` is not defined.") @@ -380,7 +382,7 @@ struct Wait <: BaseBuildingBlock end end -duration(wb::Wait) = wb.duration +@defvar duration(wb::Wait) = wb.duration Base.keys(::Wait) = (Val(:empty),) Base.getindex(wb::Wait, ::Val{:empty}) = NoGradient(wb.duration) -- GitLab