From 9f83069fde5c90cee1b0dd66ce9ccf5b524cad82 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@protonmail.com>
Date: Thu, 16 May 2024 16:31:13 +0100
Subject: [PATCH] Write pulseq files

---
 src/sequence_io/pulseq.jl | 275 +++++++++++++++++++++++++++++---------
 1 file changed, 210 insertions(+), 65 deletions(-)

diff --git a/src/sequence_io/pulseq.jl b/src/sequence_io/pulseq.jl
index 95e0943..346b67d 100644
--- a/src/sequence_io/pulseq.jl
+++ b/src/sequence_io/pulseq.jl
@@ -1,31 +1,6 @@
 module Pulseq
 
 
-"""
-    read_pulseq(IO)
-
-Reads a sequence from a pulseq file (http://pulseq.github.io/).
-Pulseq files can be produced using matlab (http://pulseq.github.io/) or python (https://pypulseq.readthedocs.io/en/master/).
-"""
-function read_pulseq(io::IO)
-    sections = parse_pulseq_sections(io)
-    return parse_all_sections(sections)
-end
-
-"""
-    write_pulseq(IO, sequence)
-
-Writes a sequence to an output IO file.
-"""
-function write_pulseq(io::IO, sequence::PulseqSequence)
-    if sequence.version < v"1.4"
-        error("Can only write to pulseq version 1.4 or later.")
-    end
-    sections = gen_all_sections(sequence)
-    write_section.(io, sections)
-end
-
-
 """
     parse_pulseq_dict(line, names, dtypes)
 
@@ -115,7 +90,18 @@ function parse_pulseq_sections(io::IO)
     return sections
 end
 
-# Writing pulseq files
+function write_pulseq_section(io::IO, section::PulseqSection{T}) where {T}
+    title = uppercase(string(T))
+    write(io, "[$title]\n")
+    for line in section.content
+        if iszero(length(line)) || line[end] != '\n'
+            line = line * '\n'
+        end
+        write(io, line)
+    end
+    write(io, "\n")
+    write(io, "\n")
+end
 
 struct PulseqShape
     samples :: Vector{Float64}
@@ -148,8 +134,8 @@ struct PulseqTrapezoid <:AnyPulseqGradient
 end
 
 struct PulseqADC
-    nsamples :: Int
-    dwell_time :: Float64
+    num :: Int
+    dwell :: Float64
     delay :: Int
     frequency :: Float64
     phase :: Float64
@@ -176,6 +162,69 @@ struct PulseqSequence
     blocks:: Vector{PulseqBlock}
 end
 
+"""
+    PulseqComponents(sequence::PulseqSequence)
+
+Indentifies and lists all the unique components in the sequence.
+"""
+struct PulseqComponents
+    shapes:: Vector{PulseqShape}
+    pulses:: Vector{PulseqRFPulse}
+    grads:: Vector{AnyPulseqGradient}
+    adc:: Vector{PulseqADC}
+    extensions:: Vector{PulseqExtension}
+end
+
+PulseqComponents() = PulseqComponents(
+    PulseqShape[],
+    PulseqRFPulse[],
+    AnyPulseqGradient[],
+    PulseqADC[],
+    PulseqExtension[],
+)
+
+add_components!(::PulseqComponents, search_vec::Vector, ::Nothing) = 0
+function add_components!(comp::PulseqComponents, search_vec::Vector{<:T}, component::T) where {T}
+    for (i, c) in enumerate(search_vec)
+        if same_component(comp, c, component)
+            return i
+        end
+    end
+    push!(search_vec, component)
+    return length(search_vec)
+end
+
+same_component(::PulseqComponents, ::Any, ::Any) = false
+function same_component(comp::PulseqComponents, a::T, b::T) where {T}
+    for name in fieldnames(T)
+        v1 = getfield(a, name)
+        v2 = getfield(b, name)
+
+        if v1 isa PulseqShape
+            v1 = add_components!(comp, v1)
+        end
+        if v2 isa PulseqShape
+            v2 = add_components!(comp, v2)
+        end
+        if v1 != v2
+            return false
+        end
+    end
+    return true
+end
+
+
+add_components!(comp::PulseqComponents, ::Nothing) = 0
+function add_components!(comp::PulseqComponents, shape::PulseqShape)
+    for (i, s) in enumerate(comp.shapes)
+        if same_shape(s, shape)
+            return i
+        end
+    end
+    push!(comp.shapes, shape)
+    return length(comp.shapes)
+end
+same_shape(shape1::PulseqShape, shape2::PulseqShape) = length(shape1.samples) == length(shape2.samples) && all(shape1.samples .≈ shape2.samples)
 
 
 # I/O all sections
@@ -195,8 +244,16 @@ function parse_all_sections(sections:: Dict{String, PulseqSection})
 end
 
 function gen_all_sections(seq:: PulseqSequence)
-    sections = [gen_section(seq, Val(symbol)) for symbol in [:version, :definitions, :shapes, :rf, :gradients, :trap, :adc, :extensions, :blocks]]
-    return [section for section in sections if length(section.content) > 0]
+    sections = Dict{Symbol, PulseqSection}()
+    sections[:version] = gen_section(seq, Val(:version))
+    sections[:definitions] = gen_section(seq, Val(:definitions))
+
+    comp = PulseqComponents()
+    sections[:blocks] = gen_section(seq, comp, Val(:blocks))
+    for symbol in [:rf, :gradients, :trap, :adc, :shapes]
+        sections[symbol] = gen_section(comp, Val(symbol))
+    end
+    return sections
 end
 
 # Version I/O
@@ -301,36 +358,9 @@ function uncompress(compressed::CompressedPulseqShape)
     return PulseqShape(amplitudes)
 end
 
-same_shape(shape1::PulseqShape, shape2::PulseqShape) = length(shape1.samples) == length(shape2.samples) && all(shape1.samples .≈ shape2.samples)
-
-"""
-    shapes(sequence::PulseqSequence)
-
-Return all the unique shapes used in the MR sequence.
-"""
-function shapes(sequence:: PulseqSequence)
-    result = PulseqShape[]
-    for block in sequence.blocks
-        for new_shape in _all_shapes(block)
-            if !any(s -> same_shape(s, new_shape), result)
-                push!(result, new_shape)
-            end
-        end
-    end
-    return result
-end
-
-_all_shapes(block::PulseqBlock) = vcat(_all_shapes.([block.gx, block.gy, block.gz, block.rf])...)
-_all_shapes(block::PulseqRFPulse) = vcat(_all_shapes.([block.magnitude, block.phase, block.time])...)
-_all_shapes(block::PulseqGradient) = vcat(_all_shapes.([block.shape, block.time])...)
-_all_shapes(shape::PulseqShape) = PulseqShape[shape]
-_all_shapes(::PulseqTrapezoid) = PulseqShape[]
-_all_shapes(::Nothing) = PulseqShape[]
-
-
-function gen_section(seq:: PulseqSequence, ::Val{:shapes})
+function gen_section(comp:: PulseqComponents, ::Val{:shapes})
     res = PulseqSection{:shapes}(String[])
-    for (index, shape) in enumerate(shapes(seq))
+    for (index, shape) in enumerate(comp.shapes)
         append!(res.content, [
             "",
             "shape_id $index",
@@ -388,6 +418,24 @@ function parse_section(section::PulseqSection{:rf}; shapes::Dict{Int, PulseqShap
     return result
 end
 
+function gen_section(comp:: PulseqComponents, ::Val{:rf})
+    res = PulseqSection{:rf}(String[])
+    for (i, pulse) in enumerate(comp.pulses)
+        values = string.(Any[
+            i,
+            pulse.amplitude,
+            add_components!(comp, pulse.magnitude),
+            add_components!(comp, pulse.phase),
+            add_components!(comp, pulse.time),
+            pulse.delay,
+            pulse.frequency,
+            pulse.phase_offset
+        ])
+        push!(res.content, join(values, " "))
+    end
+    return res
+end
+
 function parse_section(section::PulseqSection{:gradients}; shapes::Dict{Int, PulseqShape}, version::VersionNumber, kwargs...)
     result = Dict{Int, PulseqGradient}()
     for line in section.content
@@ -415,16 +463,24 @@ function parse_section(section::PulseqSection{:gradients}; shapes::Dict{Int, Pul
     return result
 end
 
-function _get_shape_id(shapes, shape)
-    for (i, s) in enumerate(shapes)
-        if same_shape(s, shape)
-            return i
+function gen_section(comp:: PulseqComponents, ::Val{:gradients})
+    res = PulseqSection{:gradients}(String[])
+    for (i, grad) in enumerate(comp.grads)
+        if !(grad isa PulseqGradient)
+            continue
         end
+        values = string.(Any[
+            i,
+            grad.amplitude,
+            add_components!(comp, grad.shape),
+            add_components!(comp, grad.time),
+            grad.delay,
+        ])
+        push!(res.content, join(values, " "))
     end
-    error("Shape not found.")
+    return res
 end
 
-
 function parse_section(section::PulseqSection{:trap}; kwargs...)
     result = Dict{Int, PulseqTrapezoid}()
     for line in section.content
@@ -444,6 +500,25 @@ function parse_section(section::PulseqSection{:trap}; kwargs...)
     return result
 end
 
+function gen_section(comp:: PulseqComponents, ::Val{:trap})
+    res = PulseqSection{:trap}(String[])
+    for (i, grad) in enumerate(comp.grads)
+        if !(grad isa PulseqTrapezoid)
+            continue
+        end
+        values = string.(Any[
+            i,
+            grad.amplitude,
+            grad.rise,
+            grad.flat,
+            grad.fall,
+            grad.delay,
+        ])
+        push!(res.content, join(values, " "))
+    end
+    return res
+end
+
 function parse_section(section::PulseqSection{:adc}; kwargs...)
     result = Dict{Int, PulseqADC}()
     for line in section.content
@@ -463,6 +538,22 @@ function parse_section(section::PulseqSection{:adc}; kwargs...)
     return result
 end
 
+function gen_section(comp:: PulseqComponents, ::Val{:adc})
+    res = PulseqSection{:adc}(String[])
+    for (i, adc) in enumerate(comp.adc)
+        values = string.(Any[
+            i,
+            adc.num,
+            adc.dwell,
+            adc.delay,
+            adc.frequency,
+            adc.phase,
+        ])
+        push!(res.content, join(values, " "))
+    end
+    return res
+end
+
 function parse_section(section::PulseqSection{:extensions}; kwargs...)
     current_extension = -1
     pre_amble = true
@@ -522,4 +613,58 @@ function parse_section(section::PulseqSection{:blocks}; version, rf=Dict(), grad
     return res
 end
 
+function gen_section(seq::PulseqSequence, comp:: PulseqComponents, ::Val{:blocks})
+    res = PulseqSection{:blocks}(String[])
+
+    for (i, block) in enumerate(seq.blocks)
+        values = Any[i, block.duration]
+        for (search_vec, part) in [
+            (comp.pulses, block.rf)
+            (comp.grads, block.gx) 
+            (comp.grads, block.gy) 
+            (comp.grads, block.gz) 
+            (comp.adc, block.adc)
+        ]
+            push!(values, add_components!(comp, search_vec, part))
+        end
+        push!(values, 0)
+        if length(block.ext) > 0
+            error("Cannot write extensions yet.")
+        end
+        push!(res.content, join(string.(values), " "))
+    end
+    return res
+end
+
+"""
+    read_pulseq(IO)
+
+Reads a sequence from a pulseq file (http://pulseq.github.io/).
+Pulseq files can be produced using matlab (http://pulseq.github.io/) or python (https://pypulseq.readthedocs.io/en/master/).
+"""
+function read_pulseq(io::IO)
+    sections = parse_pulseq_sections(io)
+    return parse_all_sections(sections)
+end
+
+"""
+    write_pulseq(IO, sequence)
+
+Writes a sequence to an output IO file.
+"""
+function write_pulseq(io::IO, sequence::PulseqSequence)
+    if sequence.version < v"1.4"
+        error("Can only write to pulseq version 1.4 or later.")
+    end
+    sections = gen_all_sections(sequence)
+    for key in [:version, :definitions, :blocks, :rf, :gradients, :trap, :adc, :shapes]
+        if length(sections[key].content) == 0
+            continue
+        end
+        write_pulseq_section(io, sections[key])
+    end
+end
+
+
+
 end
\ No newline at end of file
-- 
GitLab