-
Michiel Cottaar authoredMichiel Cottaar authored
pulseq.jl 12.64 KiB
module Pulseq
import ...Variables: duration
import ...Scanners: Scanner
import ...Components: GenericPulse, GradientWaveform, ADC
import ...Containers: BuildingBlock, Sequence, Wait
import DataStructures: OrderedDict
import Interpolations: linear_interpolation, Flat
import StaticArrays: SVector
"""
    PulseqSection(title, content)

Raw, unparsed section of a pulseq file.

- `title`: text found between the square brackets of the section header.
- `content`: the non-empty, non-comment lines following that header.
"""
struct PulseqSection
    title::String
    content::Vector{String}
end
"""
    CompressedPulseqShape(id, num, samples)

Shape as stored in the `[SHAPES]` section of a pulseq file.

- `id`: identifier of the shape.
- `num`: number of samples the shape has once uncompressed.
- `samples`: values as listed in the file. They are taken to be uncompressed
  when `length(samples) == num` and compressed otherwise.
"""
struct CompressedPulseqShape
    id::Int
    num::Int
    samples::Vector{Float64}
end

"""
    control_points(shape::CompressedPulseqShape)

Reduce a (possibly compressed) pulseq shape to the control points of a
piecewise-linear curve.

Returns a tuple with:
- vector of times (normalised to run from 0 to 1 across the shape)
- vector of amplitudes

Both elements are always freshly allocated `Vector{Float64}`s, so callers are free to
extend or mutate them.

In the compressed encoding the samples are increments. A value that is immediately
repeated is followed by a count of *additional* repeats, so `x x n` stands for
`n + 2` consecutive increments of `x`. Each such run becomes a single control point.
"""
function control_points(pulseq::CompressedPulseqShape)
    samples = pulseq.samples
    num = pulseq.num

    if length(samples) == num
        # Stored uncompressed: one evenly spaced control point per sample.
        # `range` cannot build a single-element range with distinct end points,
        # so shapes with fewer than two samples are handled separately.
        times = num > 1 ? collect(range(0., 1., length=num)) : zeros(Float64, num)
        return (times, copy(samples))
    end
    if isempty(samples)
        throw(ArgumentError("pulseq shape $(pulseq.id) claims $num samples, but none were provided"))
    end

    times = [zero(Float64)]
    amplitudes = [samples[1]]
    time_norm = 1 / (num - 1)

    # `prev_sample` is the most recently read increment; `prev_applied` tracks whether
    # it has already been added to `amplitudes`. The first sample is the starting
    # amplitude itself, so it counts as applied.
    prev_sample = samples[1]
    prev_applied = true
    repeating = false
    for sample in Iterators.drop(samples, 1)
        if repeating
            # `sample` is the repeat count following two identical increments.
            nrepeats = Int(sample) + 2
            if prev_applied
                # the first of the two identical values was already accounted for
                nrepeats -= 1
            end
            push!(times, times[end] + nrepeats * time_norm)
            push!(amplitudes, amplitudes[end] + nrepeats * prev_sample)
            repeating = false
            prev_applied = true
        elseif sample == prev_sample
            repeating = true
        else
            if !prev_applied
                # flush the pending single increment before starting a new one
                push!(times, times[end] + time_norm)
                push!(amplitudes, amplitudes[end] + prev_sample)
            end
            prev_sample = sample
            prev_applied = false
        end
    end
    if !prev_applied
        push!(times, times[end] + time_norm)
        push!(amplitudes, amplitudes[end] + prev_sample)
    end
    return (times, amplitudes)
end
"""
read_pulseq(filename; scanner=Scanner(B0=B0), B0=3., TR=<sequence duration>)
Reads a sequence from a pulseq file (http://pulseq.github.io/).
Pulseq files can be produced using matlab (http://pulseq.github.io/) or python (https://pypulseq.readthedocs.io/en/master/).
"""
function read_pulseq(filename; kwargs...)
    # Parse every section of the file; the result maps lower-cased section
    # names (as symbols) to their parsed content.
    keywords = open(read_pulseq_sections, filename)

    # The sequence is named after the file, without its ".seq" extension.
    # `chop` counts characters rather than bytes, so this is safe for
    # file names containing non-ASCII characters.
    name = basename(filename)
    if endswith(name, ".seq")
        name = chop(name; tail=4)
    end

    # User-supplied keywords override the default name; parsed sections come last.
    return build_sequence(; name=Symbol(name), kwargs..., keywords...)
end
"""
    read_pulseq_sections(io::IO)

Split the pulseq file read from `io` into its `[SECTION]`s and parse each of them.

Empty lines and lines starting with `#` are skipped. Returns a dictionary mapping the
lower-cased section title (as a `Symbol`) to the parsed content; sections for which
`parse_pulseq_section` returns `nothing` are left out.

Throws an error if content is found before the first section header.
"""
function read_pulseq_sections(io::IO)
    sections = PulseqSection[]
    for raw_line in eachline(io)
        line = strip(raw_line)
        if isempty(line) || startswith(line, '#')
            continue # ignore empty lines and comments
        end
        if startswith(line, '[') && endswith(line, ']')
            # new section starts; `chop` drops the brackets character-wise
            push!(sections, PulseqSection(chop(line; head=1, tail=1), String[]))
        elseif !isempty(sections)
            push!(sections[end].content, line)
        else
            error("Content found in pulseq file before first section")
        end
    end

    # The file version selects the column layout of several sections, so it has to be
    # known before parsing the rest, wherever the [VERSION] section is located.
    version_index = findfirst(section -> section.title == "VERSION", sections)
    version = isnothing(version_index) ? nothing : parse_pulseq_section(sections[version_index]).second

    # A comprehension (rather than broadcasting) keeps `version` a plain scalar argument.
    parsed = [parse_pulseq_section(section, version) for section in sections]
    return Dict{Symbol, Any}(pair for pair in parsed if !isnothing(pair))
end
"""
    parse_pulseq_ordered_dict(strings, names, dtypes)

Parse the whitespace-separated table in `strings` (one row per element).

`names` lists the column names (the first of which has to be `:id`) and `dtypes` the
type each column is parsed as. Returns an `OrderedDict` mapping each row's `id` to a
named tuple holding all columns of that row, in file order.

Throws an `ArgumentError` if the first column is not `:id` or if a row does not have
exactly one value per column.
"""
function parse_pulseq_ordered_dict(strings::Vector{<:AbstractString}, names, dtypes)
    if isempty(names) || names[1] != :id
        throw(ArgumentError("The first column of a pulseq table has to be `:id`"))
    end
    as_dict = OrderedDict{Int, NamedTuple{Tuple(names), Tuple{dtypes...}}}()
    for line in strings
        parts = split(line)
        if length(parts) != length(names)
            throw(ArgumentError(
                "Expected $(length(names)) columns ($(join(names, ", "))), " *
                "but found $(length(parts)) in pulseq line: \"$line\""
            ))
        end
        values = parse.(dtypes, parts)
        as_dict[values[1]] = (; zip(names, values)...)
    end
    return as_dict
end
"""
    parse_pulseq_properties(strings)

Parse lines of the form `<name> <value>` into a dictionary from name to value.

The name is the first whitespace-delimited word; the value is the remainder of the line
with surrounding whitespace removed (kept as text, not converted to a number).
A line holding only a name maps to an empty value and blank lines are skipped.
If a name occurs more than once, the last occurrence wins.
"""
function parse_pulseq_properties(strings::Vector{<:AbstractString})
    result = Dict{String, Any}()
    for s in strings
        parts = split(s, limit=2)
        isempty(parts) && continue
        # a property without a value is kept with an empty value rather than failing
        result[parts[1]] = length(parts) > 1 ? strip(parts[2]) : ""
    end
    return result
end
# Column layout of the tabular pulseq sections as `(column names, column types)`.
# A `(title, version)` key holds the layout specific to that file version and takes
# precedence over the plain `title` key, which is used for any other version.
# Declared `const` so that functions reading this table do not see an untyped global.
const section_headers = Dict(
    "BLOCKS" => ([:id, :duration, :rf, :gx, :gy, :gz, :adc, :ext], fill(Int, 8)),
    ("BLOCKS", v"1.3.1") => ([:id, :delay, :rf, :gx, :gy, :gz, :adc, :ext], fill(Int, 8)),
    ("DELAYS", v"1.3.1") => ([:id, :delay], [Int, Int]),
    ("RF", v"1.3.1") => (
        [:id, :amp, :mag_id, :phase_id, :delay, :freq, :phase],
        [Int, Float64, Int, Int, Int, Float64, Float64],
    ),
    "RF" => (
        [:id, :amp, :mag_id, :phase_id, :time_id, :delay, :freq, :phase],
        [Int, Float64, Int, Int, Int, Int, Float64, Float64],
    ),
    ("GRADIENTS", v"1.3.1") => ([:id, :amp, :shape_id, :delay], [Int, Float64, Int, Int]),
    "GRADIENTS" => ([:id, :amp, :shape_id, :time_id, :delay], [Int, Float64, Int, Int, Int]),
    "TRAP" => ([:id, :amp, :rise, :flat, :fall, :delay], [Int, Float64, Int, Int, Int, Int]),
    "ADC" => ([:id, :num, :dwell, :delay, :freq, :phase], [Int, Int, Float64, Int, Float64, Float64]),
)
# Parse the lines of a [SHAPES] section into a dictionary mapping each shape id to
# `(number of uncompressed samples, control point times, control point amplitudes)`.
function parse_pulseq_shapes(lines)
    shapes = CompressedPulseqShape[]
    # id read from a "shape_id" line that is still waiting for its "num_..." line
    pending_id = nothing
    for line in lines
        if length(line) > 8 && lowercase(first(line, 8)) == "shape_id"
            pending_id = parse(Int, chop(line; head=8, tail=0))
            continue
        end
        lowered = lowercase(line)
        for text in ("num_uncompressed", "num_samples")
            if startswith(lowered, text)
                if isnothing(pending_id)
                    error("Found \"$text\" in pulseq [SHAPES] section without a preceding shape_id")
                end
                num = parse(Int, chop(line; head=length(text), tail=0))
                push!(shapes, CompressedPulseqShape(pending_id, num, Float64[]))
                pending_id = nothing
                break
            end
        end
        if !startswith(lowered, "num")
            # anything else is a sample belonging to the most recently declared shape
            if !isnothing(pending_id)
                error("Pulseq shape $pending_id has samples before its number of samples was declared")
            end
            if isempty(shapes)
                error("Found shape sample in pulseq [SHAPES] section before any shape_id")
            end
            push!(shapes[end].samples, parse(Float64, line))
        end
    end
    return Dict(s.id => (s.num, control_points(s)...) for s in shapes)
end

"""
    parse_pulseq_section(section::PulseqSection, version=nothing)

Parse a single section of a pulseq file.

Returns `Symbol(lowercase(section.title)) => parsed_content`, where the content is:
- a `VersionNumber` for `VERSION`,
- a dictionary of name/value strings for `DEFINITIONS`,
- an ordered dictionary of rows for the tabular sections listed in `section_headers`
  (the layout registered for `version` is preferred over the generic one),
- a dictionary from shape id to `(num, times, amplitudes)` for `SHAPES`.

Returns `nothing` for sections that are ignored (extensions and the signature).
Throws an error for any other section title.
"""
function parse_pulseq_section(section::PulseqSection, version=nothing)
    title = section.title
    if title == "VERSION"
        props = parse_pulseq_properties(section.content)
        result = VersionNumber(
            parse(Int, props["major"]),
            parse(Int, props["minor"]),
            parse(Int, props["revision"]),
        )
    elseif title == "DEFINITIONS"
        result = parse_pulseq_properties(section.content)
    elseif haskey(section_headers, (title, version))
        result = parse_pulseq_ordered_dict(section.content, section_headers[(title, version)]...)
    elseif haskey(section_headers, title)
        result = parse_pulseq_ordered_dict(section.content, section_headers[title]...)
    elseif title in ("EXTENSION", "EXTENSIONS")
        # Extensions are not supported; return early because there is no parsed
        # content to pair with the title.
        @warn "Ignoring all extensions in pulseq"
        return nothing
    elseif title == "SHAPES"
        result = parse_pulseq_shapes(section.content)
    elseif title == "SIGNATURE"
        # silently ignore these sections
        return nothing
    else
        error("Unrecognised pulseq section: $(title)")
    end
    return Symbol(lowercase(title)) => result
end
"""
    align_in_time(pairs...)

Resample several piecewise-linear curves onto a shared time axis.

Each argument is a `(times, amplitudes)` pair. The shared axis is the sorted set of all
distinct times across the pairs; every curve is linearly interpolated onto it, holding
its edge value constant outside of its own time range.

Returns `(all_times, amplitudes_1, amplitudes_2, ...)` with one amplitude vector per pair.
"""
function align_in_time(pairs...)
    # one interpolator per curve, flat beyond the first/last control point
    interpolators = map(pairs) do (time, ampl)
        linear_interpolation(time, ampl; extrapolation_bc=Flat())
    end
    knots = vcat(map(first, pairs)...)
    all_times = sort(unique(knots))
    resampled = map(itp -> itp.(all_times), interpolators)
    return (all_times, resampled...)
end
"""
    build_sequence(; definitions, version, blocks, rf, gradients, trap, delays, shapes, adc, scanner, B0, TR, name)

Assemble a `Sequence` from the parsed sections of a pulseq file.

# Keywords
- `definitions`, `version`, `blocks`: parsed `[DEFINITIONS]`, `[VERSION]` and `[BLOCKS]` sections (required).
- `rf`, `gradients`, `trap`, `delays`, `shapes`, `adc`: other parsed sections (`nothing` if absent from the file).
- `scanner`: scanner to attach to the sequence; defaults to `Scanner(B0=B0)`.
- `B0`: only used to build the default scanner.
- `TR`: total duration of the sequence. Defaults to the `TotalDuration` definition
  (multiplied by 1000) when that is present and non-zero, else to the summed block durations.
  If longer than the blocks, a `Wait` is appended; if shorter, a warning is emitted and it is ignored.
- `name`: name given to the sequence.

All times are converted to the raster units established at the top of the function
(seconds in the definitions are multiplied by 1000).

Throws an error for versions other than v1.3.1 and v1.4.0, for gradient ids that are in
neither `[GRADIENTS]` nor `[TRAP]`, for time shapes that do not line up with the shape
they belong to, and for blocks whose events do not fit in the stated block duration.
"""
function build_sequence(; scanner=nothing, B0=3., TR=nothing, definitions, version, blocks, rf=nothing, gradients=nothing, trap=nothing, delays=nothing, shapes=nothing, adc=nothing, name=:Pulseq)
    if isnothing(scanner)
        scanner = Scanner(B0=B0)
    end

    if version == v"1.4.0"
        # load raster times (converting seconds to milliseconds)
        to_ms = key -> parse(Float64, definitions[key]) * 1e3
        gradient_raster = to_ms("GradientRasterTime")
        rf_raster = to_ms("RadiofrequencyRasterTime")
        adc_raster = to_ms("AdcRasterTime")  # read so that a missing definition is reported; not used below
        block_duration_raster = to_ms("BlockDurationRaster")
    elseif version == v"1.3.1"
        gradient_raster = rf_raster = block_duration_raster = 1e-3  # 1 microsecond as default raster
        adc_raster = 1e-6  # ADC dwell time is in ns by default
    else
        error("Can only load pulseq files with versions v1.3.1 and v1.4.0, not $(version)")
    end

    full_blocks = BuildingBlock[]
    for block in values(blocks)
        # minimum duration needed to fit all events of this block
        block_duration = 0.
        events = []

        if !iszero(block.rf)
            proc = rf[block.rf]
            rf_delay = proc.delay * 1e-3

            # entries of `shapes` are (number of samples, normalised times, amplitudes)
            (num, times_shape, amplitudes_shape) = shapes[proc.mag_id]
            block_duration = max(num * rf_raster + rf_delay, block_duration)
            ampl_times = times_shape .* (num * rf_raster)
            ampl_size = amplitudes_shape .* (proc.amp * 1e-3)

            if iszero(proc.phase_id)
                # no phase shape: constant zero, but frequency/phase offsets still apply below
                phase_times = ampl_times
                phase_shape = zero.(ampl_times)
            else
                (num, times_shape, phase_shape) = shapes[proc.phase_id]
                block_duration = max(num * rf_raster + rf_delay, block_duration)
                phase_times = times_shape .* (num * rf_raster)
            end
            # NOTE(review): this treats the stored phase shape as radians. The pulseq
            # format may store it in units of 2π instead — confirm against the specification.
            phase_size = rad2deg.(phase_shape) .+ phase_times .* (proc.freq * 1e-3 * 360) .+ rad2deg(proc.phase)

            if version != v"1.3.1" && !iszero(proc.time_id)
                # explicit sample times, stored as the amplitudes of a separate shape
                (_, _, time_values) = shapes[proc.time_id]
                times = time_values .* rf_raster
                ampl = ampl_size
                phase = phase_size
                if !(length(times) == length(ampl) == length(phase))
                    error("Time shape $(proc.time_id) of RF pulse $(block.rf) does not have the same number of control points as its magnitude/phase shapes")
                end
                # NOTE(review): the frequency-offset ramp above is still evaluated at the
                # evenly spaced `phase_times` rather than these explicit times — confirm.
                if !isempty(times)
                    block_duration = max(rf_delay + times[end], block_duration)
                end
            else
                (times, ampl, phase) = align_in_time((ampl_times, ampl_size), (phase_times, phase_size))
            end
            push!(events, (rf_delay, GenericPulse(times, ampl, phase)))
        end

        if !iszero(block.adc)
            proc = adc[block.adc]
            readout_duration = proc.dwell * proc.num * 1e-6
            push!(events, (proc.delay * 1e-3, ADC(proc.num, proc.dwell * 1e-6, readout_duration / 2, 1.)))
            block_duration = max(proc.delay * 1e-3 + readout_duration, block_duration)
        end

        # one (times, amplitudes) pair per axis; plain vectors so they can be extended below
        grad_shapes = Tuple{Vector{Float64}, Vector{Float64}}[]
        for symbol_grad in (:gx, :gy, :gz)
            grad_id = getfield(block, symbol_grad)
            if iszero(grad_id)
                push!(grad_shapes, (Float64[], Float64[]))
                continue
            end
            if !isnothing(gradients) && haskey(gradients, grad_id)
                proc = gradients[grad_id]
                start_time = proc.delay * 1e-3
                (num, grad_times, grad_amplitudes) = shapes[proc.shape_id]
                if version != v"1.3.1" && !iszero(proc.time_id)
                    # explicit sample times, stored as the amplitudes of a separate shape
                    (_, _, time_values) = shapes[proc.time_id]
                    if length(time_values) != length(grad_amplitudes)
                        error("Time shape $(proc.time_id) of gradient $grad_id does not have the same number of control points as its amplitude shape")
                    end
                    times = collect(Float64, time_values .* gradient_raster .+ start_time)
                    grad_end = isempty(times) ? start_time : times[end]
                else
                    # `grad_times` may be a range; collect so that the result can grow
                    times = collect(Float64, grad_times .* (num * gradient_raster) .+ start_time)
                    grad_end = start_time + num * gradient_raster
                end
                push!(grad_shapes, (times, grad_amplitudes .* (proc.amp * 1e-9)))
                block_duration = max(grad_end, block_duration)
            elseif !isnothing(trap) && haskey(trap, grad_id)
                proc = trap[grad_id]
                start_time = proc.delay * 1e-3
                amplitude = proc.amp * 1e-9
                if iszero(proc.flat)
                    # triangular gradient: leave out the redundant second plateau point,
                    # which would otherwise sit at the same time as the first one
                    edges = [0, proc.rise, proc.rise + proc.fall]
                    trap_amplitudes = [0., amplitude, 0.]
                else
                    edges = cumsum([0, proc.rise, proc.flat, proc.fall])
                    trap_amplitudes = [0., amplitude, amplitude, 0.]
                end
                times = (edges .* 1e-3) .+ start_time
                push!(grad_shapes, (times, trap_amplitudes))
                # `times` already includes the (converted) delay
                block_duration = max(times[end], block_duration)
            else
                error("Gradient ID $grad_id not found in either of [GRADIENTS] or [TRAP] sections")
            end
        end

        if version == v"1.3.1"
            # v1.3.1 has no explicit block duration, only an optional delay event
            if !iszero(block.delay)
                actual_duration = max(block_duration, delays[block.delay].delay * 1e-3)
            else
                actual_duration = block_duration
            end
        else
            actual_duration = block.duration * block_duration_raster
            if actual_duration < block_duration
                error("Pulseq block $(block.id) has a duration of $actual_duration, which is too short for its events ($block_duration)")
            end
        end

        # make every gradient axis cover the full block, at zero amplitude outside its own event
        for (times, amplitudes) in grad_shapes
            if isempty(times) || !iszero(times[1])
                pushfirst!(times, 0.)
                pushfirst!(amplitudes, 0.)
            end
            if times[end] != actual_duration
                push!(times, actual_duration)
                push!(amplitudes, 0.)
            end
        end

        arrs = align_in_time(grad_shapes...)
        waveform = [(t, SVector{3, Float64}(gx, gy, gz)) for (t, gx, gy, gz) in zip(arrs...)]
        push!(full_blocks, BuildingBlock(waveform, events))
    end

    full_block_duration = sum(duration.(full_blocks))
    if isnothing(TR)
        total_duration = parse(Float64, get(definitions, "TotalDuration", "0")) * 1000
        TR = iszero(total_duration) ? full_block_duration : total_duration
    end
    if TR > full_block_duration
        push!(full_blocks, Wait(TR - full_block_duration))
    elseif TR < full_block_duration
        @warn "Given TR ($TR) is shorter than the total duration of all the building blocks ($full_block_duration). Ignoring TR..."
    end
    return Sequence(full_blocks; scanner=scanner, name=name)
end
end