From 92d4820f605c28cfa956fbc473bc92aeba79080e Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Sun, 28 Apr 2024 17:17:03 +0100
Subject: [PATCH] Generated multiple sequences in adjust

---
 src/post_hoc.jl       | 57 ++++++++++++++++++++++++++++++++++++++++---
 test/test_post_hoc.jl | 16 ++++++++++++
 2 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/src/post_hoc.jl b/src/post_hoc.jl
index b6d1592..55ae5a5 100644
--- a/src/post_hoc.jl
+++ b/src/post_hoc.jl
@@ -10,7 +10,7 @@ import ..Containers: ContainerBlock, Sequence, Wait
 """
     adjust(block; kwargs...)
 
-Generate a new sequence/building_block/component with some post-fitting adjustments.
+Generate one or more new sequences/building_blocks/components with some post-fitting adjustments.
 
 The following adjustments are allowed:
 - for MR gradients
@@ -21,6 +21,11 @@ The following adjustments are allowed:
     - `frequency`: shift the off-resonance frequency by the given value (in kHz).
     - `scale`: multiply the RF pulse amplitude by the given value (used to model the B1 transmit field).
 
+A vector of multiple values can be passed on to any of these in order to create multiple sequences with different adjustments.
+The will usually be merged together. You can get the individual sequences by passing on `merge=false`.
+The time between these repeated sequences can be adjusted using the keywords described in [`merge_sequences`](@ref) passed on to the merge keyword:
+e.g., `merge=(wait_time=10, )` adds a wait time of 10 ms between each repeated sequence.
+
 Specific sequence components that can be adjusted are identified by their `group` name.
 For example, `adjust(sequence, dwi=(orientation=[0, 1, 0], ))` will set any gradient in the group `dwi` to point in the y-direction.
 
@@ -28,9 +33,20 @@ To affect all gradients or pulses, use `gradient=` or `pulse`, e.g.
 `adjust(sequence, pulse=(scale=0.5, ))`
 will divide the amplitude of all RV pulses by two.
 """
-function adjust(block::AbstractBlock; kwargs...) 
+function adjust(block::AbstractBlock; merge=true, kwargs...) 
     used_names = Set{Symbol}()
-    res = adjust_helper(block, used_names; kwargs...)
+    n_adjust, kwargs_list = adjust_kwargs_list(; kwargs...)
+    if isnothing(n_adjust)
+        res = adjust_helper(block, used_names; kwargs_list[1]...)
+    else
+        res = [adjust_helper(block, used_names; kw...) for kw in kwargs_list]
+        if merge !== false
+            if merge === true
+                merge = NamedTuple()
+            end
+            res = merge_sequences(res...; merge...)
+        end
+    end
     unused_names = filter(keys(kwargs)) do key
         !(key in used_names)
     end
@@ -40,6 +56,41 @@ function adjust(block::AbstractBlock; kwargs...)
     res
 end
 
+function adjust_kwargs_list(; kwargs...)
+    n_adjust = nothing
+    for (_, named_tuple) in kwargs
+        for key in keys(named_tuple)
+            value = named_tuple[key]
+            if key == :orientation && value isa AbstractVector{<:Number}
+                continue
+            end
+            if value isa AbstractVector
+                if isnothing(n_adjust)
+                    n_adjust = length(value)
+                else
+                    @assert length(value) == n_adjust
+                end
+            end
+        end
+    end
+    use_n_adjust = isnothing(n_adjust) ? 1 : n_adjust
+    kwargs_list = [Dict{Symbol, Any}([field=>Dict{Symbol, Any}() for field in keys(kwargs)]...) for _ in 1:use_n_adjust]
+    for (field, named_tuple) in kwargs
+        for key in keys(named_tuple)
+            value = named_tuple[key]
+            for index in 1:use_n_adjust
+                if (key == :orientation && value isa AbstractVector{<:Number}) || !(value isa AbstractVector)
+                    kwargs_list[index][field][key] = value
+                else
+                    kwargs_list[index][field][key] = value[index]
+                end
+            end
+        end
+    end
+    return (n_adjust, kwargs_list)
+end
+
+
 function adjust_helper(block::AbstractBlock, used_names::Set{Symbol}; gradient=(), pulse=(), kwargs...)
     params = []
     adjust_type = adjustable(block)
diff --git a/test/test_post_hoc.jl b/test/test_post_hoc.jl
index 83ced17..18e9d0d 100644
--- a/test/test_post_hoc.jl
+++ b/test/test_post_hoc.jl
@@ -29,11 +29,27 @@
                 new_dwi = adjust(dwi, DWI=(scale=0.5, orientation=[0., 1., 0.]))
                 @test bval(new_dwi) ≈ 0.25
                 @test all(qval3(new_dwi[:gradient]) .≈ [0., qval_orig/2, 0.])
+
+                @testset "multiple adjustments" begin
+                    new_dwi = adjust(dwi, DWI=(scale=[0.5, 1.], orientation=[0., 1., 0.]), merge=false)
+                    @test length(new_dwi) == 2
+                    @test bval(new_dwi[1]) ≈ 0.25
+                    @test bval(new_dwi[2]) ≈ 1.
+                    @test all(qval3(new_dwi[1][:gradient]) .≈ [0., qval_orig/2, 0.])
+                    @test all(qval3(new_dwi[2][:gradient]) .≈ [0., qval_orig, 0.])
+                    
+                    new_dwi = adjust(dwi, DWI=(scale=[0.5, 1.], orientation=[0., 1., 0.]))
+                    @test duration(new_dwi) ≈ 160
+
+                    new_dwi = adjust(dwi, DWI=(scale=[0.5, 1.], orientation=[0., 1., 0.]), merge=(wait_time=10, ))
+                    @test duration(new_dwi) ≈ 170
+                end
             end
             @testset "Rotate gradient" begin
                 new_dwi = adjust(dwi, gradient=(rotation=RotationVec(0., 0., π/4), ))
                 @test bval(new_dwi) ≈ 1.
                 @test all(qval3(new_dwi[:gradient]) .≈ [qval_orig/√2, qval_orig/√2, 0.])
+
             end
         end
     end
-- 
GitLab