From d76d3140138a67c34636b59180ced4e7ff73b1ae Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk> Date: Thu, 21 Mar 2024 18:12:58 +0000 Subject: [PATCH] Read pulseq files --- Project.toml | 2 + src/MRIBuilder.jl | 4 + src/components/pulses/generic_pulses.jl | 2 +- src/containers/base_sequences.jl | 4 +- src/pulseq.jl | 320 ++++++++++++++++++++++++ 5 files changed, 329 insertions(+), 3 deletions(-) create mode 100644 src/pulseq.jl diff --git a/Project.toml b/Project.toml index 83c65d6..fdb9322 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,8 @@ authors = ["Michiel Cottaar <Michiel.Cottaar@ndcn.ox.ac.uk>"] version = "0.0.0" [deps] +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" JuMP = "4076af6c-e467-56ae-b986-b466b2749572" Juniper = "2ddba703-00a4-53a7-87a5-e8b9971dde84" diff --git a/src/MRIBuilder.jl b/src/MRIBuilder.jl index 5cc156c..520111e 100644 --- a/src/MRIBuilder.jl +++ b/src/MRIBuilder.jl @@ -12,6 +12,7 @@ include("pathways.jl") include("parts/parts.jl") include("sequences/sequences.jl") include("printing.jl") +include("pulseq.jl") include("plot.jl") import .BuildSequences: build_sequence, global_model, global_scanner, fixed @@ -38,6 +39,9 @@ export dwi_gradients, readout_event, excitation_pulse, refocus_pulse, Trapezoid, import .Sequences: GradientEcho, SpinEcho, DiffusionSpinEcho, DW_SE, DWI export GradientEcho, SpinEcho, DiffusionSpinEcho, DW_SE, DWI +import .Pulseq: read_pulseq +export read_pulseq + import .Plot: plot_sequence export plot_sequence diff --git a/src/components/pulses/generic_pulses.jl b/src/components/pulses/generic_pulses.jl index 8b9f72e..4b3c01d 100644 --- a/src/components/pulses/generic_pulses.jl +++ b/src/components/pulses/generic_pulses.jl @@ -49,7 +49,7 @@ flip_angle(pulse::GenericPulse) = sum(get_weights(pulse) .* pulse.amplitude) * 3 function time_halfway_flip(pulse::GenericPulse) w = get_weights(pulse) flip_so_far = cumsum(w .* pulse.amplitude) - return pulse.times[findfirst(f -> f >= flip_so_far[end] / 2, flip_so_far)] + return pulse.time[findfirst(f -> f >= flip_so_far[end] / 2, flip_so_far)] end for fn in (:amplitude, :phase) diff --git a/src/containers/base_sequences.jl b/src/containers/base_sequences.jl index f98075b..0bda22d 100644 --- a/src/containers/base_sequences.jl +++ b/src/containers/base_sequences.jl @@ -137,9 +137,9 @@ struct Sequence{S, N} <: BaseSequence{N} scanner :: Scanner end -function Sequence(blocks::AbstractVector; name=:Sequence, variables...) +function Sequence(blocks::AbstractVector; name=:Sequence, scanner=nothing, variables...) blocks = to_block_pair.(blocks) - res = Sequence{name, length(blocks)}(SVector{length(blocks)}(blocks), global_scanner()) + res = Sequence{name, length(blocks)}(SVector{length(blocks)}(blocks), isnothing(scanner) ? global_scanner() : scanner) set_simple_constraints!(res, variables) return res end diff --git a/src/pulseq.jl b/src/pulseq.jl new file mode 100644 index 0000000..a907bce --- /dev/null +++ b/src/pulseq.jl @@ -0,0 +1,320 @@ +module Pulseq +import ..Variables: duration +import ..Scanners: Scanner +import ..Components: GenericPulse, GradientWaveform, ADC +import ..Containers: BuildingBlock, Sequence, Wait +import DataStructures: OrderedDict +import Interpolations: linear_interpolation, Flat +import StaticArrays: SVector + +struct PulseqSection + title :: String + content :: Vector{String} +end + +struct CompressedPulseqShape + id :: Int + num :: Int + samples :: Vector{Float64} +end + + +""" + control_points(shape::CompressedPulseqShape) + +Returns a tuple with: +- vector of times +- vector of amplitudes +""" +function control_points(pulseq::CompressedPulseqShape) + compressed = length(pulseq.samples) != pulseq.num + if !compressed + times = range(0, 1, length=pulseq.num) + return (times, pulseq.samples) + end + times = [zero(Float64)] + amplitudes = [pulseq.samples[1]] + repeating = false + time_norm = 1 / (pulseq.num - 1) + prev_sample = pulseq.samples[1] + prev_applied = true + for sample in pulseq.samples[2:end] + if repeating + nrepeats = Int(sample) + 2 + if prev_applied + nrepeats -= 1 + end + push!(times, times[end] + nrepeats * time_norm) + push!(amplitudes, amplitudes[end] + nrepeats * prev_sample) + repeating = false + prev_applied = true + elseif sample == prev_sample + repeating = true + else + if !prev_applied + push!(times, times[end] + time_norm) + push!(amplitudes, amplitudes[end] + prev_sample) + end + prev_sample = sample + prev_applied = false + end + end + if !prev_applied + push!(times, times[end] + time_norm) + push!(amplitudes, amplitudes[end] + prev_sample) + end + return (times, amplitudes) +end + +""" + read_pulseq(filename; scanner=Scanner(B0=B0), B0=3., TR=<sequence duration>) + +Reads a sequence from a pulseq file (http://pulseq.github.io/). +Pulseq files can be produced using matlab (http://pulseq.github.io/) or python (https://pypulseq.readthedocs.io/en/master/). +""" +function read_pulseq(filename; kwargs...) + keywords = open(read_pulseq_sections, filename) + name = basename(filename) + if endswith(name, ".seq") + name = name[1:end-4] + end + build_sequence(; name=Symbol(name), kwargs..., keywords...) +end + +function read_pulseq_sections(io::IO) + sections = PulseqSection[] + version = nothing + title = "" + for line in readlines(io) + line = strip(line) + if length(line) == 0 || line[1] == '#' + continue # ignore comments + end + if line[1] == '[' && line[end] == ']' + if title == "VERSION" + version = parse_pulseq_section(sections[end]).second + end + # new section starts + title = line[2:end-1] + push!(sections, PulseqSection(title, String[])) + elseif length(sections) > 0 + push!(sections[end].content, line) + else + error("Content found in pulseq file before first section") + end + end + Dict(filter(pair -> !isnothing(pair), parse_pulseq_section.(sections, version))...) +end + +function parse_pulseq_ordered_dict(strings::Vector{<:AbstractString}, names, dtypes) + as_dict = OrderedDict{Int, NamedTuple{Tuple(names), Tuple{dtypes...}}}() + for line in strings + parts = split(line) + @assert length(parts) == length(names) + values = parse.(dtypes, split(line)) + @assert names[1] == :id + as_dict[values[1]] = (; zip(names, values)...) + end + return as_dict +end + +function parse_pulseq_properties(strings::Vector{<:AbstractString}) + result = Dict{String, Any}() + for s in strings + (name, value) = split(s, limit=2) + result[name] = value + end + return result +end + +section_headers = Dict( + "BLOCKS" => ([:id, :duration, :rf, :gx, :gy, :gz, :adc, :ext], fill(Int, 8)), + ("BLOCKS", v"1.3.1") => ([:id, :delay, :rf, :gx, :gy, :gz, :adc, :ext], fill(Int, 8)), + ("DELAYS", v"1.3.1") => ([:id, :delay], [Int, Int]), + ("RF", v"1.3.1") => ([:id, :amp, :mag_id, :phase_id, :delay, :freq, :phase], [Int, Float64, Int, Int, Int, Float64, Float64]), + "RF" => ([:id, :amp, :mag_id, :phase_id, :time_id, :delay, :freq, :phase], [Int, Float64, Int, Int, Int, Int, Float64, Float64]), + ("GRADIENTS", v"1.3.1") => ([:id, :amp, :shape_id, :delay], [Int, Float64, Int, Int]), + "GRADIENTS" => ([:id, :amp, :shape_id, :time_id, :delay], [Int, Float64, Int, Int, Int]), + "TRAP" => ([:id, :amp, :rise, :flat, :fall, :delay], [Int, Float64, Int, Int, Int, Int]), + "ADC" => ([:id, :num, :dwell, :delay, :freq, :phase], [Int, Int, Float64, Int, Float64, Float64]), +) + +function parse_pulseq_section(section::PulseqSection, version=nothing) + if section.title == "VERSION" + props = parse_pulseq_properties(section.content) + result = VersionNumber( + parse(Int, props["major"]), + parse(Int, props["minor"]), + parse(Int, props["revision"]), + ) + elseif section.title == "DEFINITIONS" + result = parse_pulseq_properties(section.content) + elseif (section.title, version) in keys(section_headers) + result = parse_pulseq_ordered_dict(section.content, section_headers[(section.title, version)]...) + elseif section.title in keys(section_headers) + result = parse_pulseq_ordered_dict(section.content, section_headers[section.title]...) + elseif section.title == "EXTENSION" + println("Ignoring all extensions in pulseq") + elseif section.title == "SHAPES" + current_id = -1 + shapes = CompressedPulseqShape[] + for line in section.content + if length(line) > 8 && lowercase(line[1:8]) == "shape_id" + current_id = parse(Int, line[9:end]) + continue + end + for text in ("num_uncompressed", "num_samples") + if startswith(lowercase(line), text) + @assert current_id != -1 + push!(shapes, CompressedPulseqShape(current_id, parse(Int, line[length(text)+1:end]), Float64[])) + current_id = -1 + break + end + end + if !startswith(lowercase(line), "num") + @assert current_id == -1 + push!(shapes[end].samples, parse(Float64, line)) + end + end + result = Dict( + [s.id => (s.num, control_points(s)...) for s in shapes]... + ) + elseif section.title in ["SIGNATURE"] + # silently ignore these sections + return nothing + else + error("Unrecognised pulseq section: $(section.title)") + end + return Symbol(lowercase(section.title)) => result +end + +function align_in_time(pairs...) + interps = [linear_interpolation(time, ampl; extrapolation_bc=Flat()) for (time, ampl) in pairs] + all_times = sort(unique(vcat([time for (time, _) in pairs]...))) + return (all_times, [interp.(all_times) for interp in interps]...) +end + + +function build_sequence(; scanner=nothing, B0=3., TR=nothing, definitions, version, blocks, rf=nothing, gradients=nothing, trap=nothing, delays=nothing, shapes=nothing, adc=nothing, name=:Pulseq) + if isnothing(scanner) + scanner = Scanner(B0=B0) + end + if version == v"1.4.0" + # load raster times (converting seconds to milliseconds) + convert = key -> parse(Float64, definitions[key]) * Float64(1e3) + gradient_raster = convert("GradientRasterTime") + rf_raster = convert("RadiofrequencyRasterTime") + adc_raster = convert("AdcRasterTime") + block_duration_raster = convert("BlockDurationRaster") + elseif version == v"1.3.1" + gradient_raster = rf_raster = block_duration_raster = Float64(1e-3) # 1 microsecond as default raster + adc_raster = Float64(1e-6) # ADC dwell time is in ns by default + else + error("Can only load pulseq files with versions v1.3.1 and v1.4.0, not $(version)") + end + + full_blocks = BuildingBlock[] + + for block in values(blocks) + block_duration = 0. + events = [] + if !iszero(block.rf) + proc = rf[block.rf] + (num, times_shape, amplitudes_shape) = shapes[proc.mag_id] + block_duration = max(num * rf_raster + proc.delay * 1e-3, block_duration) + ampl_times = times_shape .* (num * rf_raster) + ampl_size = amplitudes_shape .* (proc.amp * 1e-3) + if iszero(proc.phase_id) + phase_times = ampl_times + phase_size = ampl_times .* 0 + else + (num, times_shape, phase_shape) = shapes[proc.phase_id] + block_duration = max(num * rf_raster + proc.delay * 1e-3, block_duration) + phase_times = times_shape .* (num * rf_raster) + phase_size = rad2deg.(phase_shape) .+ phase_times .* (proc.freq * 1e-3 * 360) .+ rad2deg(proc.phase) + end + if version != v"1.3.1" && !iszero(proc.time_id) + (num, time_shape) = shapes[proc.time_id] + times = time_shape.amplitudes .* rf_raster + ampl = ampl_size + phase = phase_size + else + (times, ampl, phase) = align_in_time((ampl_times, ampl_size), (phase_times, phase_size)) + end + push!(events, (proc.delay * 1e-3, GenericPulse(times, ampl, phase))) + end + if !iszero(block.adc) + proc = adc[block.adc] + push!(events, (proc.delay, ADC(proc.num, proc.dwell * 1e-6, proc.dwell * proc.num * 1e-6 / 2, 0.))) + block_duration = max(proc.delay * 1e-3 + proc.dwell * proc.num * 1e-6, block_duration) + end + grad_shapes = [] + for symbol_grad in [:gx, :gy, :gz] + grad_id = getfield(block, symbol_grad) + if iszero(grad_id) + push!(grad_shapes, (Float64[], Float64[])) + continue + end + if !isnothing(gradients) && grad_id in keys(gradients) + proc = gradients[grad_id] + start_time = proc.delay * 1e-3 + (num, grad_shape) = shapes[proc.shape_id] + + if version != v"1.3.1" && !iszero(proc.time_id) + (num, time_shape) = shapes[proc.time_id] + times = time_shape.amplitudes .* gradient_raster .+ start_time + else + times = grad_shape.times .* (num * gradient_raster) .+ start_time + end + push!(grad_shapes, (times, grad_shape.amplitudes .* (proc.amp * 1e-9))) + block_duration = max(proc.delay * 1e-3 + num * gradient_raster, block_duration) + elseif !isnothing(trap) && grad_id in keys(trap) + proc = trap[grad_id] + start_time = proc.delay * 1e-3 + times = (cumsum([0, proc.rise, proc.flat, proc.fall]) .* 1e-3) .+ start_time + push!(grad_shapes, (times, [0, proc.amp * 1e-9, proc.amp * 1e-9, 0])) + block_duration = max((proc.delay + proc.rise + proc.flat + proc.fall) * 1e-3, block_duration) + else + error("Gradient ID $grad_id not found in either of [GRADIENTS] or [TRAP] sections") + end + end + if version == v"1.3.1" + if !iszero(block.delay) + actual_duration = max(block_duration, delays[block.delay].delay * 1e-3) + else + actual_duration = block_duration + end + else + actual_duration = block.duration * block_duration_raster + @assert actual_duration >= block_duration + end + function extend_grad!(pair) + (times, amplitudes) = pair + if iszero(length(times)) || !iszero(times[1]) + pushfirst!(times, 0.) + pushfirst!(amplitudes, 0.) + end + if times[end] != actual_duration + push!(times, actual_duration) + push!(amplitudes, 0.) + end + end + extend_grad!.(grad_shapes) + arrs = align_in_time(grad_shapes...) + waveform = [(t, SVector{3, Float64}(gx, gy, gz)) for (t, gx, gy, gz) in zip(arrs...)] + push!(full_blocks, BuildingBlock(waveform, events)) + end + full_block_duration = sum(duration.(full_blocks)) + if isnothing(TR) + total_duration = parse(Float64, get(definitions, "TotalDuration", "0")) * 1000 + TR = iszero(total_duration) ? full_block_duration : total_duration + end + if TR > full_block_duration + push!(full_blocks, Wait(TR - full_block_duration)) + elseif TR < full_block_duration + @warn "Given TR ($TR) is shorter than the total duration of all the building blocks ($full_block_duration). Ignoring TR..." + end + Sequence(full_blocks; scanner=scanner, name=name) +end + +end \ No newline at end of file -- GitLab