From f4f4e56ea5a2d35607d85089f7ea63650cf90dc5 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Thu, 25 Apr 2024 13:28:01 +0100
Subject: [PATCH] Add adjust function

---
 Project.toml                                  |  1 +
 src/MRIBuilder.jl                             |  4 +
 .../changing_gradient_blocks.jl               | 26 +++++-
 .../constant_gradient_blocks.jl               | 28 ++++++-
 src/components/instant_gradients.jl           | 25 +++++-
 src/components/pulses/constant_pulses.jl      | 11 ++-
 src/components/pulses/generic_pulses.jl       |  9 +-
 src/components/pulses/instant_pulses.jl       |  6 +-
 src/components/pulses/sinc_pulses.jl          | 15 +++-
 src/post_hoc.jl                               | 82 +++++++++++++++++++
 src/variables.jl                              | 12 +++
 11 files changed, 212 insertions(+), 7 deletions(-)
 create mode 100644 src/post_hoc.jl

diff --git a/Project.toml b/Project.toml
index 88ef04c..734bcb6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,6 +14,7 @@ MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
 Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
+Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
 Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 
diff --git a/src/MRIBuilder.jl b/src/MRIBuilder.jl
index 65f1070..f05ab1e 100644
--- a/src/MRIBuilder.jl
+++ b/src/MRIBuilder.jl
@@ -10,6 +10,7 @@ include("components/components.jl")
 include("containers/containers.jl")
 include("pathways.jl")
 include("parts/parts.jl")
+include("post_hoc.jl")
 include("sequences/sequences.jl")
 include("printing.jl")
 include("sequence_io/sequence_io.jl")
@@ -36,6 +37,9 @@ export Pathway, duration_transverse, duration_dephase, bval, bmat, get_pathway
 import .Parts: dwi_gradients, readout_event, excitation_pulse, refocus_pulse, Trapezoid, SliceSelect, LineReadout, opposite_kspace_lines, SpoiltSliceSelect, SliceSelectRephase, EPIReadout, interpret_image_size
 export dwi_gradients, readout_event, excitation_pulse, refocus_pulse, Trapezoid, SliceSelect, LineReadout, opposite_kspace_lines, SpoiltSliceSelect, SliceSelectRephase, EPIReadout, interpret_image_size
 
+import .PostHoc: adjust
+export adjust
+
 import .Sequences: GradientEcho, SpinEcho, DiffusionSpinEcho, DW_SE, DWI
 export GradientEcho, SpinEcho, DiffusionSpinEcho, DW_SE, DWI
 
diff --git a/src/components/gradient_waveforms/changing_gradient_blocks.jl b/src/components/gradient_waveforms/changing_gradient_blocks.jl
index 5d59d7d..7dc1409 100644
--- a/src/components/gradient_waveforms/changing_gradient_blocks.jl
+++ b/src/components/gradient_waveforms/changing_gradient_blocks.jl
@@ -1,6 +1,8 @@
 module ChangingGradientBlocks
 import StaticArrays: SVector
-import ....Variables: VariableType, duration, qval, bmat_gradient, gradient_strength, slew_rate, get_free_variable
+import Rotations: RotMatrix3
+import LinearAlgebra: I
+import ....Variables: VariableType, duration, qval, bmat_gradient, gradient_strength, slew_rate, get_free_variable, adjust_internal
 import ...AbstractTypes: GradientWaveform
 
 
@@ -99,5 +101,27 @@ function split_gradient(cgb::ChangingGradient, times::VariableType...)
     end
 end
 
+function adjust_internal(cgb::ChangingGradient1D; orientation=nothing, scale=1., rotation=nothing)
+    if !isnothing(orientation) && !isnothing(rotation)
+        error("Cannot set both the gradient orientation and rotation.")
+    end
+    new_orientation = isnothing(orientation) ? (isnothing(rotation) ? cgb.orientation : rotation * cgb.orientation) : orientation
+    return ChangingGradient1D(
+        cgb.gradient_strength_start * scale,
+        cgb.slew_rate * scale,
+        new_orientation,
+        cgb.duration,
+        cgb.group
+    )
+end
+
+function adjust_internal(cgb::ChangingGradient3D; scale=1., rotation=RotMatrix3(I(3)))
+    return ChangingGradient3D(
+        rotation * (cgb.gradient_strength_start .* scale),
+        rotation * (cgb.slew_rate .* scale),
+        cgb.duration,
+        cgb.group
+    )
+end
 
 end
diff --git a/src/components/gradient_waveforms/constant_gradient_blocks.jl b/src/components/gradient_waveforms/constant_gradient_blocks.jl
index b42a69f..c411c84 100644
--- a/src/components/gradient_waveforms/constant_gradient_blocks.jl
+++ b/src/components/gradient_waveforms/constant_gradient_blocks.jl
@@ -1,6 +1,6 @@
 module ConstantGradientBlocks
 import StaticArrays: SVector
-import ....Variables: VariableType, duration, qval, bmat_gradient, gradient_strength, slew_rate, get_free_variable
+import ....Variables: VariableType, duration, qval, bmat_gradient, gradient_strength, slew_rate, get_free_variable, adjust_internal
 import ...AbstractTypes: GradientWaveform
 import ..ChangingGradientBlocks: split_gradient
 
@@ -72,4 +72,30 @@ function split_gradient(cgb::ConstantGradient, times::VariableType...)
     end
 end
 
+function adjust_internal(cgb::ConstantGradient1D; orientation=nothing, scale=1., rotation=nothing)
+    if !isnothing(orientation) && !isnothing(rotation)
+        error("Cannot set both the gradient orientation and rotation.")
+    end
+    new_orientation = isnothing(orientation) ? (isnothing(rotation) ? cgb.orientation : rotation * cgb.orientation) : orientation
+    return ConstantGradient1D(
+        cgb.gradient_strength * scale,
+        new_orientation,
+        cgb.duration,
+        cgb.group
+    )
+end
+
+function adjust_internal(cgb::ConstantGradient3D; scale=1., rotation=nothing)
+    return ConstantGradient3D(
+        (
+            isnothing(rotation) ? 
+            (cgb.gradient_strength .* scale) : 
+            (rotation * (cgb.gradient_strength .* scale))
+        ),
+        cgb.duration,
+        cgb.group
+    )
+end
+
+
 end
diff --git a/src/components/instant_gradients.jl b/src/components/instant_gradients.jl
index 989e894..ba0eddc 100644
--- a/src/components/instant_gradients.jl
+++ b/src/components/instant_gradients.jl
@@ -1,7 +1,7 @@
 module InstantGradients
 import StaticArrays: SVector, SMatrix
 import JuMP: @constraint
-import ...Variables: VariableType, duration, qval, bmat_gradient, get_free_variable, set_simple_constraints!, effective_time, make_generic
+import ...Variables: VariableType, duration, qval, bmat_gradient, get_free_variable, set_simple_constraints!, effective_time, make_generic, adjust_internal
 import ...BuildSequences: global_model
 import ..AbstractTypes: EventComponent, GradientWaveform
 
@@ -70,4 +70,27 @@ bmat_gradient(::InstantGradient, qstart=nothing) = zero(SMatrix{3, 3, Float64, 3
 
 make_generic(ig::InstantGradient) = ig
 
+function adjust_internal(ig::InstantGradient1D; orientation=nothing, scale=1., rotation=nothing)
+    if !isnothing(orientation) && !isnothing(rotation)
+        error("Cannot set both the gradient orientation and rotation.")
+    end
+    new_orientation = isnothing(orientation) ? (isnothing(rotation) ? ig.orientation : rotation * ig.orientation) : orientation
+    return InstantGradient1D(
+        ig.qval * scale,
+        new_orientation,
+        ig.group
+    )
+end
+
+function adjust_internal(ig::InstantGradient3D; scale=1., rotation=nothing)
+    return InstantGradient3D(
+        (
+            isnothing(rotation) ? 
+            (ig.qvec .* scale) : 
+            (rotation * (ig.qvec .* scale))
+        ),
+        ig.group
+    )
+end
+
 end
\ No newline at end of file
diff --git a/src/components/pulses/constant_pulses.jl b/src/components/pulses/constant_pulses.jl
index fc692b0..a15ea30 100644
--- a/src/components/pulses/constant_pulses.jl
+++ b/src/components/pulses/constant_pulses.jl
@@ -2,7 +2,7 @@ module ConstantPulses
 import JuMP: @constraint
 import ...AbstractTypes: RFPulseComponent, split_timestep
 import ....BuildSequences: global_model
-import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, set_simple_constraints!, frequency, make_generic, get_free_variable
+import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, set_simple_constraints!, frequency, make_generic, get_free_variable, adjust_internal
 import ..GenericPulses: GenericPulse
 
 """
@@ -64,5 +64,14 @@ end
 
 split_timestep(pulse::ConstantPulse, precision) = Inf
 
+function adjust_internal(block::ConstantPulse; scale=1., frequency=0.)
+    ConstantPulse(
+        block.amplitude * scale,
+        block.duration,
+        block.phase,
+        block.frequency + frequency,
+        block.group,
+    )
+end
 
 end
\ No newline at end of file
diff --git a/src/components/pulses/generic_pulses.jl b/src/components/pulses/generic_pulses.jl
index 74ded02..856bf8a 100644
--- a/src/components/pulses/generic_pulses.jl
+++ b/src/components/pulses/generic_pulses.jl
@@ -2,7 +2,7 @@ module GenericPulses
 
 import Polynomials: fit
 import ...AbstractTypes: RFPulseComponent, split_timestep
-import ....Variables: duration, amplitude, effective_time, flip_angle, make_generic, phase, frequency
+import ....Variables: duration, amplitude, effective_time, flip_angle, make_generic, phase, frequency, adjust_internal
 
 
 """
@@ -153,5 +153,12 @@ function split_timestep(gp::GenericPulse, precision)
     return sqrt(2 * precision / max_second_der)
 end
 
+function adjust_internal(block::GenericPulse; scale=1., frequency=0.)
+    GenericPulse(
+        block.time,
+        block.amplitude .* scale,
+        block.phase .+ (360. * frequency) .* (block.time .- effective_time(block))
+    )
+end
 
 end
diff --git a/src/components/pulses/instant_pulses.jl b/src/components/pulses/instant_pulses.jl
index 3675969..367c346 100644
--- a/src/components/pulses/instant_pulses.jl
+++ b/src/components/pulses/instant_pulses.jl
@@ -2,7 +2,7 @@ module InstantPulses
 import JuMP: @constraint
 import ...AbstractTypes: RFPulseComponent
 import ....BuildSequences: global_model
-import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, make_generic, get_free_variable
+import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, make_generic, get_free_variable, adjust_internal
 
 """
     InstantPulse(; flip_angle=nothing, phase=nothing, group=nothing)
@@ -41,4 +41,8 @@ inverse_bandwidth(::InstantPulse) = 0.
 
 make_generic(block::InstantPulse) = block
 
+function adjust_internal(block::InstantPulse; scale=1., frequency=0.)
+    InstantPulse(block.flip_angle * scale, block.phase, block.group)
+end
+
 end
\ No newline at end of file
diff --git a/src/components/pulses/sinc_pulses.jl b/src/components/pulses/sinc_pulses.jl
index 9f7d590..2d6cacd 100644
--- a/src/components/pulses/sinc_pulses.jl
+++ b/src/components/pulses/sinc_pulses.jl
@@ -2,7 +2,7 @@ module SincPulses
 import JuMP: @constraint
 import QuadGK: quadgk
 import ....BuildSequences: global_model
-import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, set_simple_constraints!, frequency, make_generic, get_free_variable
+import ....Variables: duration, amplitude, effective_time, flip_angle, phase, inverse_bandwidth, VariableType, set_simple_constraints!, frequency, make_generic, get_free_variable, adjust_internal
 import ...AbstractTypes: RFPulseComponent, split_timestep
 import ..GenericPulses: GenericPulse
 
@@ -107,4 +107,17 @@ function split_timestep(block::SincPulse, precision)
     return sqrt(2 * precision / max_second_derivative)
 end
 
+function adjust_internal(block::SincPulse; scale=1., frequency=0.)
+    SincPulse(
+        block.apodise,
+        block.Nzeros,
+        block.norm_flip_angle,
+        block.amplitude * scale,
+        block.phase,
+        block.frequency + frequency,
+        block.lobe_duration,
+        block.group
+    )
+end
+
 end
\ No newline at end of file
diff --git a/src/post_hoc.jl b/src/post_hoc.jl
new file mode 100644
index 0000000..aaca892
--- /dev/null
+++ b/src/post_hoc.jl
@@ -0,0 +1,82 @@
+"""
+Define post-fitting adjustments of the sequences
+"""
+module PostHoc
+
+import ..Variables: AbstractBlock, adjust_internal
+import ..Components: GradientWaveform, RFPulseComponent, BaseComponent, NoGradient
+import ..Containers: ContainerBlock
+
+"""
+    adjust(block; kwargs...)
+
+Generate a new sequence/building_block/component with some post-fitting adjustments.
+
+The following adjustments are allowed:
+- for MR gradients
+    - `orientation`: set the orientation to a given vector.
+    - `rotation`: rotate the gradient orientations using a rotations from [`Rotations.jl`](https://juliageometry.github.io/Rotations.jl/stable/).
+    - `scale`: multiply the gradient strength by the given value. Note that if you use a value not between -1 and 1 you might break the scanner's maximum gradient or slew rate.
+- for RF pulses:
+    - `frequency`: shift the off-resonance frequency by the given value (in kHz).
+    - `scale`: multiply the RF pulse amplitude by the given value (used to model the B1 transmit field).
+
+Specific sequence components that can be adjusted are identified by their `group` name.
+For example, `adjust(sequence, dwi=(orientation=[0, 1, 0], ))` will set any gradient in the group `dwi` to point in the y-direction.
+
+To affect all gradients or pulses, use `gradient=` or `pulse`, e.g.
+`adjust(sequence, pulse=(scale=0.5, ))`
+will divide the amplitude of all RV pulses by two.
+"""
+function adjust(block::AbstractBlock; kwargs...) 
+    used_names = Set{Symbol}()
+    res = adjust_helper(block, used_names; kwargs...)
+    unused_names = filter(keys(kwargs)) do key
+        !(key in used_names)
+    end
+    if length(unused_names) > 0
+        @warn "Some group/type names were not used in call to `MRIBuilder.adjust`, namely: $(unused_names)."
+    end
+    res
+end
+
+function adjust_helper(container::ContainerBlock, used_names::Set{Symbol}; kwargs...)
+    params = []
+    for prop_name in propertynames(container)
+        push!(params, adjust_helper(getproperty(container, prop_name), used_names; kwargs...))
+    end
+    return typeof(container)(params...)
+end
+
+adjust_helper(some_value, used_names::Set{Symbol}; kwargs...) = some_value
+adjust_helper(array_variable::AbstractArray, used_names::Set{Symbol}; kwargs...) = map(array_variable) do v adjust_helper(v, used_names; kwargs...) end
+adjust_helper(dict_variable::AbstractDict, used_names::Set{Symbol}; kwargs...) = typeof(dict_variable)(k => adjust_helper(v, used_names; kwargs...) for (k, v) in pairs(dict_variable))
+adjust_helper(tuple_variable::Tuple, used_names::Set{Symbol}; kwargs...) = map(tuple_variable) do v adjust_helper(v, used_names; kwargs...) end
+adjust_helper(pair:: Pair, used_names::Set{Symbol}; kwargs...) = adjust_helper(pair[1], used_names; kwargs...) => adjust_helper(pair[2], used_names; kwargs...)
+
+adjust_helper(block::NoGradient, used_names::Set{Symbol}; kwargs...) = block
+
+function adjust_helper(block::GradientWaveform, used_names::Set{Symbol}; gradient=(), kwargs...)
+    if !isnothing(block.group) && (block.group in keys(kwargs))
+        push!(used_names, block.group)
+        return adjust_internal(block; kwargs[block.group]...)
+    else
+        push!(used_names, :gradient)
+        return adjust_internal(block; gradient...)
+    end
+    return new_block
+end
+
+function adjust_helper(block::RFPulseComponent, used_names::Set{Symbol}; pulse=(), kwargs...)
+    if !isnothing(block.group) && (block.group in keys(kwargs))
+        push!(used_names, block.group)
+        return adjust_internal(block; kwargs[block.group]...)
+    else
+        push!(used_names, :pulse)
+        return adjust_internal(block; pulse...)
+    end
+end
+
+adjust_helper(block::BaseComponent) = block
+
+end
\ No newline at end of file
diff --git a/src/variables.jl b/src/variables.jl
index 591385e..b93eb39 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -30,6 +30,18 @@ function fixed(ab::AbstractBlock)
     return typeof(ab)(params...)
 end
 
+
+"""
+    adjust_internal(block, names_used; kwargs...)
+
+Returns the adjusted blocks and add any keywords used in the process to `names_used`.
+
+This is a helper function used by [`adjust`](@ref). 
+"""
+function adjust_internal end
+
+
+
 all_variables_symbols = [
     :block => [
         :duration => "duration of the building block in ms.",
-- 
GitLab