Commit 068f71a3 authored by Mo Shahdloo's avatar Mo Shahdloo
Browse files

Merge branch 'master' of https://git.fmrib.ox.ac.uk/shahdloo/pymapvbvd

Conflicts:
	mapVBVD/mapVBVD.py
	mapVBVD/read_twix_hdr.py
	mapVBVD/twix_map_obj.py
	requirements.txt
parents 8ffa4187 5863e813
......@@ -9,32 +9,42 @@ import logging
class twix_map_obj:
@property
def filename(self):
return self.fname
@property
def rampSampTrj(self):
return self.rstrj
@property
def dataType(self):
return self.dType
@property
def fullSize(self):
return self.fsize
if self.full_size is None:
self.clean()
return self.full_size
# @fullSize.setter
# def fullSize(self, val):
# self.fsize = val
# self.full_size = val
@property
def dataSize(self):
# self.clean()
# print(self.fsize)
out = self.fsize
if out is None:
self.clean()
out = self.fsize
# Not yet implemented
out = self.fullSize.copy()
if self.removeOS:
ix = self.dataDims.index('Col')
out[ix] = self.NCol / 2
if self.flagAverageDim[0] | self.flagAverageDim[1]:
if self.average_dim[0] | self.average_dim[1]:
print('averaging in col and cha dim not supported, resetting flag')
self.flagAverageDim[0:2] = False
self.average_dim[0:2] = False
out[self.flagAverageDim] = 1
out[self.average_dim] = 1
return out
@property
......@@ -58,49 +68,57 @@ class twix_map_obj:
def flagRemoveOS(self, removeOS):
self.removeOS = removeOS
@property
def flagAverageDim(self):
return self.average_dim
@flagAverageDim.setter
def flagAverageDim(self, val):
self.average_dim = val
@property
def flagDoAverage(self):
ix = self.dataDims.index('Ave')
return self.flagAverageDim[ix]
return self.average_dim[ix]
@flagDoAverage.setter
def flagDoAverage(self, bval):
ix = self.dataDims.index('Ave')
self.flagAverageDim[ix] = bval
self.average_dim[ix] = bval
@property
def flagAverageReps(self):
ix = self.dataDims.index('Rep')
return self.flagAverageDim[ix]
return self.average_dim[ix]
@flagAverageReps.setter
def flagAverageReps(self, bval):
ix = self.dataDims.index('Rep')
self.flagAverageDim[ix] = bval
self.average_dim[ix] = bval
@property
def flagAverageSets(self):
ix = self.dataDims.index('Set')
return self.flagAverageDim[ix]
return self.average_dim[ix]
@flagAverageSets.setter
def flagAverageSets(self, bval):
ix = self.dataDims.index('Set')
self.flagAverageDim[ix] = bval
self.average_dim[ix] = bval
@property
def flagIgnoreSeg(self):
ix = self.dataDims.index('Seg')
return self.flagAverageDim[ix]
return self.average_dim[ix]
@flagIgnoreSeg.setter
def flagIgnoreSeg(self, bval):
ix = self.dataDims.index('Seg')
self.flagAverageDim[ix] = bval
self.average_dim[ix] = bval
@property
def flagSkipToFirstLine(self):
return self.flagSkipToFirstLine
return self.skipToFirstLine
@flagSkipToFirstLine.setter
def flagSkipToFirstLine(self, bval):
......@@ -114,8 +132,8 @@ class twix_map_obj:
self.skipLin = 0
self.skipPar = 0
self.fsize[2] = np.maximum(1, self.NLin - self.skipLin)
self.fsize[3] = np.maximum(1, self.NPar - self.skipPar)
self.full_size[2] = np.maximum(1, self.NLin - self.skipLin)
self.full_size[3] = np.maximum(1, self.NPar - self.skipPar)
@property
def flagRampSampRegrid(self):
......@@ -123,11 +141,29 @@ class twix_map_obj:
@flagRampSampRegrid.setter
def flagRampSampRegrid(self, bval):
if bval and self.rampSampTrj is None:
if bval and self.rstrj is None:
raise Exception('No trajectory for regridding available')
self.regrid = bval
# TODO: flagDoRawDataCorrect, RawDataCorrectionFactors
# TODO: flagDoRawDataCorrect
@property
def flagDoRawDataCorrect(self):
return False
@flagDoRawDataCorrect.setter
def flagDoRawDataCorrect(self, bval):
pass
# TODO: RawDataCorrectionFactors
@property
def RawDataCorrectionFactors(self):
return []
@RawDataCorrectionFactors.setter
def RawDataCorrectionFactors(self, bval):
pass
def __init__(self, dataType, fname, version, rstraj=None, **kwargs):
self.ignoreROoffcenter = kwargs.get('ignoreROoffcenter', False)
......@@ -139,8 +175,8 @@ class twix_map_obj:
self.ignoreSeg = kwargs.get('ignoreSeg', False)
self.squeeze = kwargs.get('squeeze', False)
self.dataType = dataType.lower()
self.filename = fname
self.dType = dataType.lower()
self.fname = fname
self.softwareVersion = version
# self.IsReflected = logical([]);
......@@ -169,7 +205,7 @@ class twix_map_obj:
else:
raise ValueError('software version not supported')
self.rampSampTrj = rstraj
self.rstrj = rstraj
if rstraj is None:
self.regrid = False
......@@ -225,26 +261,26 @@ class twix_map_obj:
self.skipLin = None
self.skipPar = None
self.fsize = None
self.full_size = None
# Flags
self.flagAverageDim = np.full(16, False, dtype=np.bool)
self.flagAverageDim[self.dataDims.index('Ave')] = self.doAverage
self.flagAverageDim[self.dataDims.index('Rep')] = self.averageReps
self.flagAverageDim[self.dataDims.index('Set')] = self.averageSets
self.flagAverageDim[self.dataDims.index('Seg')] = self.ignoreSeg
self.average_dim = np.full(16, False, dtype=np.bool)
self.average_dim[self.dataDims.index('Ave')] = self.doAverage
self.average_dim[self.dataDims.index('Rep')] = self.averageReps
self.average_dim[self.dataDims.index('Set')] = self.averageSets
self.average_dim[self.dataDims.index('Seg')] = self.ignoreSeg
if self.dataType == 'image' or self.dataType == 'phasestab':
if self.dType == 'image' or self.dType == 'phasestab':
self.skipToFirstLine = False
else:
self.skipToFirstLine = True
def __str__(self):
des_str = ('***twix_map_obj***\n'
f'File: {self.filename}\n'
f'File: {self.fname}\n'
f'Software: {self.softwareVersion}\n'
f'Number of acquisitions read {self.NAcq}\n'
f'Data size is {np.array2string(self.fsize, formatter={"float": lambda x: "%.0f" % x}, separator=",")}\n'
f'Data size is {np.array2string(self.fullSize, formatter={"float": lambda x: "%.0f" % x}, separator=",")}\n'
f'Squeezed data size is {np.array2string(self.sqzSize, formatter={"int": lambda x: "%i" % x}, separator=",")} ({self.sqzDims})\n'
f'NCol = {self.NCol:0.0f}\n'
f'NCha = {self.NCha:0.0f}\n'
......@@ -326,7 +362,7 @@ class twix_map_obj:
while not isLastAcqGood and self.NAcq > 0 and cnt < 100:
try:
self.clean()
# self.clean()
self.unsorted(self.NAcq)
isLastAcqGood = True
except Exception as e:
......@@ -383,7 +419,7 @@ class twix_map_obj:
if self.NCha.ndim > 0:
self.NCha = self.NCha[0]
if self.dataType == 'refscan':
if self.dType == 'refscan':
# pehses: check for lines with 'negative' line/partition numbers
# this can happen when the reference scan line/partition range
# exceeds the one of the actual imaging scan
......@@ -412,7 +448,7 @@ class twix_map_obj:
NLinAlloc = np.maximum(1, self.NLin - self.skipLin)
NParAlloc = np.maximum(1, self.NPar - self.skipPar)
self.fsize = np.array(
self.full_size = np.array(
[self.NCol, self.NCha, NLinAlloc, NParAlloc,
self.NSli, self.NAve, self.NPhs, self.NEco,
self.NRep, self.NSet, self.NSeg, self.NIda,
......@@ -480,10 +516,10 @@ class twix_map_obj:
selRangeSz[idx] = k.size
# now select all indices for the dims that are averaged
for iDx, k in enumerate(np.nditer(self.flagAverageDim)):
for iDx, k in enumerate(np.nditer(self.average_dim)):
if k:
self.clean()
selRange[iDx] = np.arange(0, self.fsize[iDx])
selRange[iDx] = np.arange(0, self.fullSize[iDx])
return selRange, selRangeSz, outSize
......@@ -491,7 +527,7 @@ class twix_map_obj:
# calculate indices to target & source(raw)
LinIx = self.Lin - self.skipLin
ParIx = self.Par - self.skipPar
sz = self.fsize[2:]
sz = self.fullSize[2:]
ixToTarget = np.zeros(LinIx.size, dtype=int)
for i, _ in enumerate(ixToTarget):
......@@ -505,7 +541,7 @@ class twix_map_obj:
# now calc. inverse index (page table: virtual to physical addresses)
# indices of lines that are not measured are zero
ixToRaw = np.full(self.fsize[2:].prod().astype(int), np.nan, dtype=int)
ixToRaw = np.full(self.fullSize[2:].prod().astype(int), np.nan, dtype=int)
for i, itt in enumerate(ixToTarget):
ixToRaw[itt] = i
......@@ -537,7 +573,7 @@ class twix_map_obj:
# a property - slower but safer (and easier to keep track of updates)
ixToRaw, _ = self.calcIndices()
tmp = np.arange(0, self.fsize[2:].prod().astype(int)).reshape(self.fsize[2:].astype(int))
tmp = np.arange(0, self.fullSize[2:].prod().astype(int)).reshape(self.fullSize[2:].astype(int))
# tmpSelRange = [x-1 for x in selRange] # python indexing from 0
for i, ids in enumerate(selRange[2:]):
tmp = np.take(tmp, ids.astype(int), i)
......@@ -556,33 +592,33 @@ class twix_map_obj:
# calculate ixToTarg for possibly smaller, shifted + segmented
# target matrix:
cIx = np.zeros((14, ixToRaw.size), dtype=int)
if ~self.flagAverageDim[2]:
if ~self.average_dim[2]:
cIx[0, :] = self.Lin[ixToRaw] - self.skipLin
if ~self.flagAverageDim[3]:
if ~self.average_dim[3]:
cIx[1, :] = self.Par[ixToRaw] - self.skipPar
if ~self.flagAverageDim[4]:
if ~self.average_dim[4]:
cIx[2, :] = self.Sli[ixToRaw]
if ~self.flagAverageDim[5]:
if ~self.average_dim[5]:
cIx[3, :] = self.Ave[ixToRaw]
if ~self.flagAverageDim[6]:
if ~self.average_dim[6]:
cIx[4, :] = self.Phs[ixToRaw]
if ~self.flagAverageDim[7]:
if ~self.average_dim[7]:
cIx[5, :] = self.Eco[ixToRaw]
if ~self.flagAverageDim[8]:
if ~self.average_dim[8]:
cIx[6, :] = self.Rep[ixToRaw]
if ~self.flagAverageDim[9]:
if ~self.average_dim[9]:
cIx[7, :] = self.Set[ixToRaw]
if ~self.flagAverageDim[10]:
if ~self.average_dim[10]:
cIx[8, :] = self.Seg[ixToRaw]
if ~self.flagAverageDim[11]:
if ~self.average_dim[11]:
cIx[9, :] = self.Ida[ixToRaw]
if ~self.flagAverageDim[12]:
if ~self.average_dim[12]:
cIx[10, :] = self.Idb[ixToRaw]
if ~self.flagAverageDim[13]:
if ~self.average_dim[13]:
cIx[11, :] = self.Idc[ixToRaw]
if ~self.flagAverageDim[14]:
if ~self.average_dim[14]:
cIx[12, :] = self.Idd[ixToRaw]
if ~self.flagAverageDim[15]:
if ~self.average_dim[15]:
cIx[13, :] = self.Ide[ixToRaw]
# import pdb; pdb.set_trace()
......@@ -630,7 +666,7 @@ class twix_map_obj:
return N.astype(idxClass)
def _fileopen(self):
fid = open(self.filename, 'rb')
fid = open(self.fname, 'rb')
return fid
def readData(self, mem, cIxToTarg=None, cIxToRaw=None, selRange=None, selRangeSz=None, outSize=None):
......@@ -668,7 +704,7 @@ class twix_map_obj:
keepOS = np.concatenate([list(range(int(self.NCol / 4))), list(range(int(self.NCol * 3 / 4), int(self.NCol)))])
bIsReflected = self.IsReflected[cIxToRaw]
bRegrid = self.regrid and self.rampSampTrj.size > 1
bRegrid = self.regrid and self.rstrj.size > 1
slicedata = self.slicePos[cIxToRaw, :]
ro_shift = self.ROoffcenter[cIxToRaw] * int(not self.ignoreROoffcenter)
# %SRY store information about raw data correction
......@@ -700,8 +736,8 @@ class twix_map_obj:
if bRegrid:
v1 = np.array(range(1, selRangeSz[1] * blockSz + 1))
rsTrj = [self.rampSampTrj, v1]
trgTrj = np.linspace(np.min(self.rampSampTrj), np.max(self.rampSampTrj), int(self.NCol))
rsTrj = [self.rstrj, v1]
trgTrj = np.linspace(np.min(self.rstrj), np.max(self.rstrj), int(self.NCol))
trgTrj = [trgTrj, v1]
# counter for proper scaling of averages/segments
......
import numpy as np
from dataclasses import dataclass, field
from read_twix_hdr import read_twix_hdr
from twix_map_obj import twix_map_obj
from tqdm import tqdm, trange
def bitget(number, pos):
return (number >> pos) & 1
def set_bit(v, index, x):
# Set the index:th bit of v to 1 if x is truthy, else to 0, and return the new value."""
mask = 1 << index # Compute mask, an integer with just bit 'index' set.
v &= mask # Clear the bit indicated by the mask (if x is False)
if x:
v |= mask # If x was True, set the bit indicated by the mask.
return v # Return the result, we're done.
def loop_mdh_read(fid, version, Nscans, scan, measOffset, measLength):
if version == 'vb':
isVD = False
byteMDH = 128
elif version == 'vd':
isVD = True
byteMDH = 184
szScanHeader = 192 # [bytes]
szChannelHeader = 32 # [bytes]
else:
isVD = False
byteMDH = 128
import warnings
warnings.warn(f'Software version "{version}" is not supported.')
cPos = fid.tell()
n_acq = 0
allocSize = 4096
ulDMALength = byteMDH
isEOF = False
mdh_blob = np.zeros((byteMDH, 0), dtype=np.uint8)
szBlob = mdh_blob.shape[1] # pylint: disable=E1136 # pylint/issues/3139
filePos = np.zeros((0), dtype=float)
fid.seek(cPos, 0)
# constants and conditional variables
bit_0 = np.array(2 ** 0, dtype=np.uint8)
bit_5 = np.array(2 ** 5, dtype=np.uint8)
mdhStart = -byteMDH # Different to matlab - index = -128
u8_000 = np.zeros((3, 1), dtype=np.uint8)
# 20 fill bytes in VD (21:40)
# Subtract one from Idx numbers to account for indexing from 0 in python
evIdx = np.array(21 + 20 * isVD, dtype=np.uint8) - 1 # 1st byte of evalInfoMask
dmaIdx = np.array(np.arange(29, 33) + 20 * isVD, dtype=np.uint8) - 1 # to correct DMA length using NCol and NCha
if isVD:
dmaOff = szScanHeader
dmaSkip = szChannelHeader
else:
dmaOff = 0
dmaSkip = byteMDH
t = tqdm(total=np.float(str('%8.1f' % (measLength/1024**2))), desc='Scan %d/%d, read all mdhs' % (scan + 1, Nscans), leave=True)
while True:
# Read mdh as binary (uint8) and evaluate as little as possible to know...
# ... where the next mdh is (ulDMALength / ushSamplesInScan & ushUsedChannels)
# ... whether it is only for sync (MDH_SYNCDATA)
# ... whether it is the last one (MDH_ACQEND)
# evalMDH() contains the correct and readable code for all mdh entries.
try:
# read everything and cut out the mdh
data_u8 = np.fromfile(fid, dtype=np.uint8, count=int(ulDMALength))
data_u8 = data_u8[mdhStart:]
except EOFError:
import warnings
warningString = f'\nAn unexpected read error occurred at this byte offset: {cPos} ({cPos / 1024 ** 3} GiB)\n'
warningString += 'Will stop reading now.\n'
warnings.warn(warningString)
isEOF = True
break
bitMask = data_u8[evIdx] # the initial 8 bit from evalInfoMask are enough
# print(bitMask)
if ((data_u8[0:3] == u8_000).all()) or (bitMask & bit_0):
# ok, look closer if really all *4* bytes are 0
data_u8[3] = bitget(data_u8[3], 0) # ubit24: keep only 1 bit from the 4th byte
tmp = data_u8[0:4]
tmp.dtype = np.uint32
ulDMALength = float(tmp)
if (ulDMALength == 0) or (bitMask & bit_0):
cPos = cPos + ulDMALength
# jump to next full 512 bytes
if cPos % 512:
cPos = cPos + 512 - cPos % 512
# isEOF = True
break
if (bitMask & bit_5): # MDH_SYNCDATA
data_u8[3] = bitget(data_u8[3], 0) # ubit24: keep only 1 bit from the 4th byte
tmp = data_u8[0:4]
tmp.dtype = np.uint32
ulDMALength = float(tmp)
cPos = cPos + ulDMALength
continue
# pehses: the pack bit indicates that multiple ADC are packed into one
# DMA, often in EPI scans (controlled by fRTSetReadoutPackaging in IDEA)
# since this code assumes one adc (x NCha) per DMA, we have to correct
# the "DMA length"
# if mdh.ulPackBit
# it seems that the packbit is not always set correctly
tmp = data_u8[dmaIdx]
tmp.dtype = np.uint16
NCol_NCha = tmp # was float [ushSamplesInScan ushUsedChannels]
ulDMALength = dmaOff + (8 * NCol_NCha[0] + dmaSkip) * NCol_NCha[1]
n_acq = n_acq + 1
# grow arrays in batches
if n_acq > szBlob:
grownArray = np.zeros((mdh_blob.shape[0], allocSize),
dtype=np.uint8) # pylint: disable=E1136 # pylint/issues/3139
mdh_blob = np.concatenate((mdh_blob, grownArray), axis=1)
filePos = np.concatenate((filePos, np.zeros((allocSize))), axis=0)
szBlob = mdh_blob.shape[1] # pylint: disable=E1136 # pylint/issues/3139
mdh_blob[:, n_acq - 1] = data_u8
filePos[n_acq - 1] = cPos
# progress = (cPos - measOffset) / measLength
t.update(np.float(str('%8.1f' % (cPos/1024**2))))
cPos = cPos + ulDMALength
t.close()
if isEOF:
n_acq = n_acq - 1 # ignore the last attempt
# import pdb; pdb.set_trace()
filePos[n_acq] = cPos
# discard overallocation:
mdh_blob = mdh_blob[:, :n_acq]
filePos = filePos[:n_acq] # in matlab was converted to row vector
# elapsed_time = time.time() - start
# print(f'{measLength / 1024 ** 2:8.1f}MB read in {elapsed_time:4.0f} s\n')
return mdh_blob, filePos, isEOF
def evalMDH(mdh_blob, version):
if version == 'vd':
isVD = True
mdh_blob = np.concatenate((mdh_blob[0:20, :], mdh_blob[40:, :]), axis=0) # remove 20 unnecessary bytes
else:
isVD = False
Nmeas = mdh_blob.shape[1]
ulPackBit = bitget(mdh_blob[3, :], 2)
ulPCI_rx = set_bit(mdh_blob[3, :], 7, False) # keep 6 relevant bits
ulPCI_rx = set_bit(ulPCI_rx, 8, False)
mdh_blob[3, :] = bitget(mdh_blob[3, :], 1) # ubit24: keep only 1 bit from the 4th byte
data_uint32 = np.ascontiguousarray(mdh_blob[0:76, :].transpose())
data_uint32.dtype = np.uint32
data_uint16 = np.ascontiguousarray(mdh_blob[28:, :].transpose())
data_uint16.dtype = np.uint16
data_single = np.ascontiguousarray(mdh_blob[68:, :].transpose())
data_single.dtype = np.single
@dataclass
class MDH: # byte pos
ulPackBit: np.uint8
ulPCI_rx: np.uint8
SlicePos: np.single
aushIceProgramPara: np.uint16
aushFreePara: np.uint16
lMeasUID: np.uint32 = data_uint32[:, 2 - 1] # 5 : 8
ulScanCounter: np.uint32 = data_uint32[:, 3 - 1] # 9 : 12
ulTimeStamp: np.uint32 = data_uint32[:, 4 - 1] # 13 : 16
ulPMUTimeStamp: np.uint32 = data_uint32[:, 5 - 1] # 17 : 20
aulEvalInfoMask: np.uint32 = data_uint32[:, 5:7] # 21 : 28
ushSamplesInScan: np.uint16 = data_uint16[:, 1 - 1] # 29 : 30
ushUsedChannels: np.uint16 = data_uint16[:, 2 - 1] # 31 : 32
sLC: np.uint16 = data_uint16[:, 2:16] # 33 : 60
sCutOff: np.uint16 = data_uint16[:, 16:18] # 61 : 64
ushKSpaceCentreColumn: np.uint16 = data_uint16[:, 19 - 1] # 66 : 66
ushCoilSelect: np.uint16 = data_uint16[:, 20 - 1] # 67 : 68
fReadOutOffcentre: np.single = data_single[:, 1 - 1] # 69 : 72
ulTimeSinceLastRF: np.uint32 = data_uint32[:, 19 - 1] # 73 : 76
ushKSpaceCentreLineNo: np.uint16 = data_uint16[:, 25 - 1] # 77 : 78
ushKSpaceCentrePartitionNo: np.uint16 = data_uint16[:, 26 - 1] # 79 : 80
if isVD:
mdh = MDH(ulPackBit, ulPCI_rx, data_single[:, 3:10], data_uint16[:, 40:64], data_uint16[:, 64:68])
else:
mdh = MDH(ulPackBit, ulPCI_rx, data_single[:, 7:14], data_uint16[:, 26:30], data_uint16[:, 30:34])
evalInfoMask1 = mdh.aulEvalInfoMask[:, 0]
@dataclass
class MASK:
MDH_ACQEND = np.minimum(evalInfoMask1 & 2 ** 0, 1)
MDH_RTFEEDBACK = np.minimum(evalInfoMask1 & 2 ** 1, 1)
MDH_HPFEEDBACK = np.minimum(evalInfoMask1 & 2 ** 2, 1)
MDH_SYNCDATA = np.minimum(evalInfoMask1 & 2 ** 5, 1)
MDH_RAWDATACORRECTION = np.minimum(evalInfoMask1 & 2 ** 10, 1)
MDH_REFPHASESTABSCAN = np.minimum(evalInfoMask1 & 2 ** 14, 1)
MDH_PHASESTABSCAN = np.minimum(evalInfoMask1 & 2 ** 15, 1)
MDH_SIGNREV = np.minimum(evalInfoMask1 & 2 ** 17, 1)
MDH_PHASCOR = np.minimum(evalInfoMask1 & 2 ** 21, 1)
MDH_PATREFSCAN = np.minimum(evalInfoMask1 & 2 ** 22, 1)
MDH_PATREFANDIMASCAN = np.minimum(evalInfoMask1 & 2 ** 23, 1)
MDH_REFLECT = np.minimum(evalInfoMask1 & 2 ** 24, 1)
MDH_NOISEADJSCAN = np.minimum(evalInfoMask1 & 2 ** 25, 1)
MDH_VOP = np.minimum(mdh.aulEvalInfoMask[:, 1] & 2 ** (53 - 32),
1) # WTC modified this as the original matlab code didn't make sense
MDH_IMASCAN = np.ones(Nmeas, dtype=np.uint32)
mask = MASK()
noImaScan = (mask.MDH_ACQEND | mask.MDH_RTFEEDBACK | mask.MDH_HPFEEDBACK
| mask.MDH_PHASCOR | mask.MDH_NOISEADJSCAN | mask.MDH_PHASESTABSCAN
| mask.MDH_REFPHASESTABSCAN | mask.MDH_SYNCDATA
| (mask.MDH_PATREFSCAN & ~mask.MDH_PATREFANDIMASCAN))
mask.MDH_IMASCAN -= noImaScan
return mdh, mask
# Overload 'dict' to enable dot access to keys