Skip to content
Snippets Groups Projects
Verified Commit d76d3140 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

Read pulseq files

parent 0bfa5869
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,8 @@ authors = ["Michiel Cottaar <Michiel.Cottaar@ndcn.ox.ac.uk>"]
version = "0.0.0"
[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Juniper = "2ddba703-00a4-53a7-87a5-e8b9971dde84"
......
......@@ -12,6 +12,7 @@ include("pathways.jl")
include("parts/parts.jl")
include("sequences/sequences.jl")
include("printing.jl")
include("pulseq.jl")
include("plot.jl")
import .BuildSequences: build_sequence, global_model, global_scanner, fixed
......@@ -38,6 +39,9 @@ export dwi_gradients, readout_event, excitation_pulse, refocus_pulse, Trapezoid,
import .Sequences: GradientEcho, SpinEcho, DiffusionSpinEcho, DW_SE, DWI
export GradientEcho, SpinEcho, DiffusionSpinEcho, DW_SE, DWI
import .Pulseq: read_pulseq
export read_pulseq
import .Plot: plot_sequence
export plot_sequence
......
......@@ -49,7 +49,7 @@ flip_angle(pulse::GenericPulse) = sum(get_weights(pulse) .* pulse.amplitude) * 3
function time_halfway_flip(pulse::GenericPulse)
w = get_weights(pulse)
flip_so_far = cumsum(w .* pulse.amplitude)
return pulse.times[findfirst(f -> f >= flip_so_far[end] / 2, flip_so_far)]
return pulse.time[findfirst(f -> f >= flip_so_far[end] / 2, flip_so_far)]
end
for fn in (:amplitude, :phase)
......
......@@ -137,9 +137,9 @@ struct Sequence{S, N} <: BaseSequence{N}
scanner :: Scanner
end
function Sequence(blocks::AbstractVector; name=:Sequence, variables...)
function Sequence(blocks::AbstractVector; name=:Sequence, scanner=nothing, variables...)
blocks = to_block_pair.(blocks)
res = Sequence{name, length(blocks)}(SVector{length(blocks)}(blocks), global_scanner())
res = Sequence{name, length(blocks)}(SVector{length(blocks)}(blocks), isnothing(scanner) ? global_scanner() : scanner)
set_simple_constraints!(res, variables)
return res
end
......
module Pulseq
import ..Variables: duration
import ..Scanners: Scanner
import ..Components: GenericPulse, GradientWaveform, ADC
import ..Containers: BuildingBlock, Sequence, Wait
import DataStructures: OrderedDict
import Interpolations: linear_interpolation, Flat
import StaticArrays: SVector
struct PulseqSection
title :: String
content :: Vector{String}
end
struct CompressedPulseqShape
id :: Int
num :: Int
samples :: Vector{Float64}
end
"""
control_points(shape::CompressedPulseqShape)
Returns a tuple with:
- vector of times
- vector of amplitudes
"""
function control_points(pulseq::CompressedPulseqShape)
compressed = length(pulseq.samples) != pulseq.num
if !compressed
times = range(0, 1, length=pulseq.num)
return (times, pulseq.samples)
end
times = [zero(Float64)]
amplitudes = [pulseq.samples[1]]
repeating = false
time_norm = 1 / (pulseq.num - 1)
prev_sample = pulseq.samples[1]
prev_applied = true
for sample in pulseq.samples[2:end]
if repeating
nrepeats = Int(sample) + 2
if prev_applied
nrepeats -= 1
end
push!(times, times[end] + nrepeats * time_norm)
push!(amplitudes, amplitudes[end] + nrepeats * prev_sample)
repeating = false
prev_applied = true
elseif sample == prev_sample
repeating = true
else
if !prev_applied
push!(times, times[end] + time_norm)
push!(amplitudes, amplitudes[end] + prev_sample)
end
prev_sample = sample
prev_applied = false
end
end
if !prev_applied
push!(times, times[end] + time_norm)
push!(amplitudes, amplitudes[end] + prev_sample)
end
return (times, amplitudes)
end
"""
read_pulseq(filename; scanner=Scanner(B0=B0), B0=3., TR=<sequence duration>)
Reads a sequence from a pulseq file (http://pulseq.github.io/).
Pulseq files can be produced using matlab (http://pulseq.github.io/) or python (https://pypulseq.readthedocs.io/en/master/).
"""
function read_pulseq(filename; kwargs...)
keywords = open(read_pulseq_sections, filename)
name = basename(filename)
if endswith(name, ".seq")
name = name[1:end-4]
end
build_sequence(; name=Symbol(name), kwargs..., keywords...)
end
function read_pulseq_sections(io::IO)
sections = PulseqSection[]
version = nothing
title = ""
for line in readlines(io)
line = strip(line)
if length(line) == 0 || line[1] == '#'
continue # ignore comments
end
if line[1] == '[' && line[end] == ']'
if title == "VERSION"
version = parse_pulseq_section(sections[end]).second
end
# new section starts
title = line[2:end-1]
push!(sections, PulseqSection(title, String[]))
elseif length(sections) > 0
push!(sections[end].content, line)
else
error("Content found in pulseq file before first section")
end
end
Dict(filter(pair -> !isnothing(pair), parse_pulseq_section.(sections, version))...)
end
function parse_pulseq_ordered_dict(strings::Vector{<:AbstractString}, names, dtypes)
as_dict = OrderedDict{Int, NamedTuple{Tuple(names), Tuple{dtypes...}}}()
for line in strings
parts = split(line)
@assert length(parts) == length(names)
values = parse.(dtypes, split(line))
@assert names[1] == :id
as_dict[values[1]] = (; zip(names, values)...)
end
return as_dict
end
function parse_pulseq_properties(strings::Vector{<:AbstractString})
result = Dict{String, Any}()
for s in strings
(name, value) = split(s, limit=2)
result[name] = value
end
return result
end
section_headers = Dict(
"BLOCKS" => ([:id, :duration, :rf, :gx, :gy, :gz, :adc, :ext], fill(Int, 8)),
("BLOCKS", v"1.3.1") => ([:id, :delay, :rf, :gx, :gy, :gz, :adc, :ext], fill(Int, 8)),
("DELAYS", v"1.3.1") => ([:id, :delay], [Int, Int]),
("RF", v"1.3.1") => ([:id, :amp, :mag_id, :phase_id, :delay, :freq, :phase], [Int, Float64, Int, Int, Int, Float64, Float64]),
"RF" => ([:id, :amp, :mag_id, :phase_id, :time_id, :delay, :freq, :phase], [Int, Float64, Int, Int, Int, Int, Float64, Float64]),
("GRADIENTS", v"1.3.1") => ([:id, :amp, :shape_id, :delay], [Int, Float64, Int, Int]),
"GRADIENTS" => ([:id, :amp, :shape_id, :time_id, :delay], [Int, Float64, Int, Int, Int]),
"TRAP" => ([:id, :amp, :rise, :flat, :fall, :delay], [Int, Float64, Int, Int, Int, Int]),
"ADC" => ([:id, :num, :dwell, :delay, :freq, :phase], [Int, Int, Float64, Int, Float64, Float64]),
)
function parse_pulseq_section(section::PulseqSection, version=nothing)
if section.title == "VERSION"
props = parse_pulseq_properties(section.content)
result = VersionNumber(
parse(Int, props["major"]),
parse(Int, props["minor"]),
parse(Int, props["revision"]),
)
elseif section.title == "DEFINITIONS"
result = parse_pulseq_properties(section.content)
elseif (section.title, version) in keys(section_headers)
result = parse_pulseq_ordered_dict(section.content, section_headers[(section.title, version)]...)
elseif section.title in keys(section_headers)
result = parse_pulseq_ordered_dict(section.content, section_headers[section.title]...)
elseif section.title == "EXTENSION"
println("Ignoring all extensions in pulseq")
elseif section.title == "SHAPES"
current_id = -1
shapes = CompressedPulseqShape[]
for line in section.content
if length(line) > 8 && lowercase(line[1:8]) == "shape_id"
current_id = parse(Int, line[9:end])
continue
end
for text in ("num_uncompressed", "num_samples")
if startswith(lowercase(line), text)
@assert current_id != -1
push!(shapes, CompressedPulseqShape(current_id, parse(Int, line[length(text)+1:end]), Float64[]))
current_id = -1
break
end
end
if !startswith(lowercase(line), "num")
@assert current_id == -1
push!(shapes[end].samples, parse(Float64, line))
end
end
result = Dict(
[s.id => (s.num, control_points(s)...) for s in shapes]...
)
elseif section.title in ["SIGNATURE"]
# silently ignore these sections
return nothing
else
error("Unrecognised pulseq section: $(section.title)")
end
return Symbol(lowercase(section.title)) => result
end
function align_in_time(pairs...)
interps = [linear_interpolation(time, ampl; extrapolation_bc=Flat()) for (time, ampl) in pairs]
all_times = sort(unique(vcat([time for (time, _) in pairs]...)))
return (all_times, [interp.(all_times) for interp in interps]...)
end
function build_sequence(; scanner=nothing, B0=3., TR=nothing, definitions, version, blocks, rf=nothing, gradients=nothing, trap=nothing, delays=nothing, shapes=nothing, adc=nothing, name=:Pulseq)
if isnothing(scanner)
scanner = Scanner(B0=B0)
end
if version == v"1.4.0"
# load raster times (converting seconds to milliseconds)
convert = key -> parse(Float64, definitions[key]) * Float64(1e3)
gradient_raster = convert("GradientRasterTime")
rf_raster = convert("RadiofrequencyRasterTime")
adc_raster = convert("AdcRasterTime")
block_duration_raster = convert("BlockDurationRaster")
elseif version == v"1.3.1"
gradient_raster = rf_raster = block_duration_raster = Float64(1e-3) # 1 microsecond as default raster
adc_raster = Float64(1e-6) # ADC dwell time is in ns by default
else
error("Can only load pulseq files with versions v1.3.1 and v1.4.0, not $(version)")
end
full_blocks = BuildingBlock[]
for block in values(blocks)
block_duration = 0.
events = []
if !iszero(block.rf)
proc = rf[block.rf]
(num, times_shape, amplitudes_shape) = shapes[proc.mag_id]
block_duration = max(num * rf_raster + proc.delay * 1e-3, block_duration)
ampl_times = times_shape .* (num * rf_raster)
ampl_size = amplitudes_shape .* (proc.amp * 1e-3)
if iszero(proc.phase_id)
phase_times = ampl_times
phase_size = ampl_times .* 0
else
(num, times_shape, phase_shape) = shapes[proc.phase_id]
block_duration = max(num * rf_raster + proc.delay * 1e-3, block_duration)
phase_times = times_shape .* (num * rf_raster)
phase_size = rad2deg.(phase_shape) .+ phase_times .* (proc.freq * 1e-3 * 360) .+ rad2deg(proc.phase)
end
if version != v"1.3.1" && !iszero(proc.time_id)
(num, time_shape) = shapes[proc.time_id]
times = time_shape.amplitudes .* rf_raster
ampl = ampl_size
phase = phase_size
else
(times, ampl, phase) = align_in_time((ampl_times, ampl_size), (phase_times, phase_size))
end
push!(events, (proc.delay * 1e-3, GenericPulse(times, ampl, phase)))
end
if !iszero(block.adc)
proc = adc[block.adc]
push!(events, (proc.delay, ADC(proc.num, proc.dwell * 1e-6, proc.dwell * proc.num * 1e-6 / 2, 0.)))
block_duration = max(proc.delay * 1e-3 + proc.dwell * proc.num * 1e-6, block_duration)
end
grad_shapes = []
for symbol_grad in [:gx, :gy, :gz]
grad_id = getfield(block, symbol_grad)
if iszero(grad_id)
push!(grad_shapes, (Float64[], Float64[]))
continue
end
if !isnothing(gradients) && grad_id in keys(gradients)
proc = gradients[grad_id]
start_time = proc.delay * 1e-3
(num, grad_shape) = shapes[proc.shape_id]
if version != v"1.3.1" && !iszero(proc.time_id)
(num, time_shape) = shapes[proc.time_id]
times = time_shape.amplitudes .* gradient_raster .+ start_time
else
times = grad_shape.times .* (num * gradient_raster) .+ start_time
end
push!(grad_shapes, (times, grad_shape.amplitudes .* (proc.amp * 1e-9)))
block_duration = max(proc.delay * 1e-3 + num * gradient_raster, block_duration)
elseif !isnothing(trap) && grad_id in keys(trap)
proc = trap[grad_id]
start_time = proc.delay * 1e-3
times = (cumsum([0, proc.rise, proc.flat, proc.fall]) .* 1e-3) .+ start_time
push!(grad_shapes, (times, [0, proc.amp * 1e-9, proc.amp * 1e-9, 0]))
block_duration = max((proc.delay + proc.rise + proc.flat + proc.fall) * 1e-3, block_duration)
else
error("Gradient ID $grad_id not found in either of [GRADIENTS] or [TRAP] sections")
end
end
if version == v"1.3.1"
if !iszero(block.delay)
actual_duration = max(block_duration, delays[block.delay].delay * 1e-3)
else
actual_duration = block_duration
end
else
actual_duration = block.duration * block_duration_raster
@assert actual_duration >= block_duration
end
function extend_grad!(pair)
(times, amplitudes) = pair
if iszero(length(times)) || !iszero(times[1])
pushfirst!(times, 0.)
pushfirst!(amplitudes, 0.)
end
if times[end] != actual_duration
push!(times, actual_duration)
push!(amplitudes, 0.)
end
end
extend_grad!.(grad_shapes)
arrs = align_in_time(grad_shapes...)
waveform = [(t, SVector{3, Float64}(gx, gy, gz)) for (t, gx, gy, gz) in zip(arrs...)]
push!(full_blocks, BuildingBlock(waveform, events))
end
full_block_duration = sum(duration.(full_blocks))
if isnothing(TR)
total_duration = parse(Float64, get(definitions, "TotalDuration", "0")) * 1000
TR = iszero(total_duration) ? full_block_duration : total_duration
end
if TR > full_block_duration
push!(full_blocks, Wait(TR - full_block_duration))
elseif TR < full_block_duration
@warn "Given TR ($TR) is shorter than the total duration of all the building blocks ($full_block_duration). Ignoring TR..."
end
Sequence(full_blocks; scanner=scanner, name=name)
end
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment