#!/usr/bin/env python

# mrsi.py - main MRSI class definition
#
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
#         Will Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2020 University of Oxford
# SHBASECOPYRIGHT

import numpy as np
import matplotlib.pyplot as plt
from fsl_mrs.core import MRS
from fsl_mrs.utils import misc


class MRSI(object):
    """Container for a 4D MRSI dataset of shape (x, y, z, FID points).

    Provides iteration over masked voxels as single-voxel ``MRS`` objects,
    an optional per-voxel water reference, basis pass-through, tissue
    segmentation lookup and helpers to map per-voxel results back onto
    the spatial grid.
    """

    def __init__(self, FID, header=None,
                 cf=None, bw=None, nucleus='1H',
                 mask=None, basis=None, names=None,
                 basis_hdr=None, H2O=None):
        """
        :param FID: array of shape (x, y, z, points).
        :param header: header dict; if None one is built from cf, bw and nucleus.
        :param cf: central frequency (only used when header is None).
        :param bw: bandwidth (only used when header is None).
        :param nucleus: resonant nucleus (only used when header is None).
        :param mask: None, or an array whose first three dimensions match
            FID; non-zero entries are included.
        :param basis: basis passed through to every MRS object.
        :param names: basis names.
        :param basis_hdr: basis header.
        :param H2O: None, or a water reference array of the same shape as FID.
        :raises ValueError: for inconsistent shapes or missing header information.
        """
        if len(FID.shape) != 4:
            raise ValueError(f'FID must be a 4D array (x, y, z, points), got shape {FID.shape}.')

        # Process H2O. Without a water reference a None-filled placeholder of
        # the FID's shape is stored, so self.H2O can always be indexed like
        # self.data; _h2o_at() converts the placeholder back into None.
        if H2O is None:
            H2O = np.full(FID.shape, None)
        elif H2O.shape != FID.shape:
            raise ValueError('H2O must be None or numpy array '
                             'of the same shape as FID.')

        # Load into properties
        self.data = FID
        self.H2O = H2O

        # Process mask (also sets self.num_masked_voxels)
        self.set_mask(mask)

        if header is not None:
            self.header = header
        elif cf is not None and bw is not None:
            self.header = {'centralFrequency': cf,
                           'bandwidth': bw,
                           'ResonantNucleus': nucleus}
        else:
            raise ValueError('Either header or cf and bw must not be None.')

        # Basis
        self.basis = basis
        self.names = names
        self.basis_hdr = basis_hdr

        # Tissue segmentation (populated by set_tissue_seg)
        self.csf = None
        self.wm = None
        self.gm = None
        self.tissue_seg_loaded = False

        # Helpful properties
        self.spatial_shape = self.data.shape[:3]
        self.FID_points = self.data.shape[3]
        self.num_voxels = np.prod(self.spatial_shape)
        if self.names is not None:
            self.num_basis = len(names)

        # MRS output options, applied to every MRS object by _process_mrs
        self.conj_basis = False
        self.no_conj_basis = False
        self.conj_FID = False
        self.no_conj_FID = False
        self.rescale = False
        self.keep = None
        self.ignore = None
        self.ind_scaling = None

        # Scalings recorded during the most recent iteration
        self._store_scalings = None

    def __iter__(self):
        """Yield (mrs, index, tissue_seg) for each masked voxel.

        Voxels are visited in np.ndindex (C) order. ``tissue_seg`` is a dict
        with 'CSF', 'WM' and 'GM' entries when segmentation is loaded,
        otherwise None. The scaling of every yielded MRS object is recorded
        for get_scalings_in_order().
        """
        self._store_scalings = []
        for idx in np.ndindex(self.data.shape[:3]):
            if not self.mask[idx]:
                continue
            mrs_out = MRS(FID=self.data[idx],
                          header=self.header,
                          basis=self.basis,
                          names=self.names,
                          basis_hdr=self.basis_hdr,
                          H2O=self._h2o_at(idx))
            self._process_mrs(mrs_out)
            self._store_scalings.append(mrs_out.scaling)

            if self.tissue_seg_loaded:
                tissue_seg = self.seg_by_index(idx)
            else:
                tissue_seg = None

            yield mrs_out, idx, tissue_seg

    @staticmethod
    def _is_placeholder(values):
        """Return True if ``values`` is the all-None placeholder used when no water is given."""
        return values.dtype == object and all(v is None for v in values.flat)

    def _h2o_at(self, index):
        """Return the water FID at spatial ``index`` (3-tuple), or None if none was supplied."""
        h2o = self.H2O[index]
        return None if self._is_placeholder(h2o) else h2o

    def _masked_average(self, volume):
        """Average a (x, y, z, points) array over the masked voxels.

        :raises ValueError: if the mask selects no voxels.
        """
        selected = volume[self.mask]  # shape (n_masked, points)
        if selected.shape[0] == 0:
            raise ValueError('Cannot average: the mask selects no voxels.')
        return selected.mean(axis=0)

    def get_indicies_in_order(self, mask=True):
        """Return a list of iteration indicies in order.

        :param mask: if True only masked voxels are returned, otherwise all voxels.
        :return: list of (x, y, z) tuples in the same order __iter__ uses.
        """
        return [idx for idx in np.ndindex(self.data.shape[:3])
                if not mask or self.mask[idx]]

    def get_scalings_in_order(self, mask=True):
        """Return a list of MRS object scalings in order.

        The scalings are those recorded by the most recent iteration over
        this object. ``mask`` is accepted for API symmetry but unused.

        :raises ValueError: if the object has not been iterated yet.
        """
        if self._store_scalings is None:
            raise ValueError('Fetch mrs by iterable first.')
        return self._store_scalings

    def mrs_by_index(self, index):
        ''' Return MRS object by index (tuple - x,y,z).'''
        idx = (index[0], index[1], index[2])
        mrs_out = MRS(FID=self.data[idx],
                      header=self.header,
                      basis=self.basis,
                      names=self.names,
                      basis_hdr=self.basis_hdr,
                      H2O=self._h2o_at(idx))
        self._process_mrs(mrs_out)
        return mrs_out

    def mrs_from_average(self):
        '''
        Return average of all masked voxels
        as a single MRS object.

        The water reference is averaged too when one was supplied;
        otherwise the returned object gets H2O=None.

        :raises ValueError: if the mask selects no voxels.
        '''
        FID = self._masked_average(self.data)
        if self._is_placeholder(self.H2O[self.mask]):
            H2O = None
        else:
            H2O = self._masked_average(self.H2O)

        mrs_out = MRS(FID=FID,
                      header=self.header,
                      basis=self.basis,
                      names=self.names,
                      basis_hdr=self.basis_hdr,
                      H2O=H2O)
        self._process_mrs(mrs_out)
        return mrs_out

    def seg_by_index(self, index):
        '''Return segmentation information by index.

        :raises ValueError: if set_tissue_seg has not been called.
        '''
        if not self.tissue_seg_loaded:
            raise ValueError('Load tissue segmentation first.')
        return {'CSF': self.csf[index],
                'WM': self.wm[index],
                'GM': self.gm[index]}

    def _process_mrs(self, mrs):
        ''' Process (conjugate, rescale)
            basis and FID and apply basis operations
            to all voxels.
        '''
        if self.basis is not None:
            if self.conj_basis:
                mrs.conj_Basis()
            elif self.no_conj_basis:
                pass
            else:
                mrs.check_Basis(repair=True)

            mrs.keep(self.keep)
            mrs.ignore(self.ignore)

        if self.conj_FID:
            mrs.conj_FID()
        elif self.no_conj_FID:
            pass
        else:
            mrs.check_FID(repair=True)

        if self.rescale:
            mrs.rescaleForFitting(ind_scaling=self.ind_scaling)

    def plot(self, mask=True, ppmlim=(0.2, 4.2)):
        '''Plot (masked) grid of spectra, one figure per slice.

        :param mask: if True only voxels inside self.mask are plotted,
            otherwise every voxel is plotted.
        :param ppmlim: ppm range passed to the MRS axis/spectrum getters.
        :raises ValueError: if no voxels are selected.
        '''
        if mask:
            selection = self.mask
        else:
            selection = np.full(self.mask.shape, True)
        voxels = np.argwhere(selection)
        if voxels.shape[0] == 0:
            raise ValueError('No voxels selected to plot.')

        # Bounding box of the selected voxels
        lower = voxels.min(axis=0)
        upper = voxels.max(axis=0)
        size1, size2, size3 = (upper - lower + 1).tolist()

        ar1 = size1 / (size1 + size2)
        ar2 = size2 / (size1 + size2)

        for sDx in range(size3):
            slice_idx = lower[2] + sDx
            # squeeze=False keeps `axes` 2D even when a grid dimension is 1
            fig, axes = plt.subplots(size1, size2,
                                     figsize=(20 * ar2, 20 * ar1),
                                     squeeze=False)
            # Only voxels of the current slice belong in this figure
            for i, j, k in voxels[voxels[:, 2] == slice_idx]:
                ax = axes[i - lower[0], j - lower[1]]
                mrs = self.mrs_by_index([i, j, k])
                ax.plot(mrs.getAxes(ppmlim=ppmlim), np.real(mrs.get_spec(ppmlim=ppmlim)))
                ax.invert_xaxis()
                ax.set_xticks([])
                ax.set_yticks([])
            fig.subplots_adjust(left=0.03,    # the left side of the subplots of the figure
                                right=0.97,   # the right side of the subplots of the figure
                                bottom=0.01,  # the bottom of the subplots of the figure
                                top=0.95,     # the top of the subplots of the figure
                                wspace=0,     # the amount of width reserved for space between subplots
                                hspace=0)
            fig.suptitle(f'Slice {slice_idx}')
            plt.show()

    def __str__(self):
        return f'MRSI with shape {self.data.shape}\n' \
               f'Number of voxels = {self.num_voxels}\n' \
               f'Number of masked voxels = {self.num_masked_voxels}'

    def __repr__(self):
        return str(self)

    def set_mask(self, mask):
        """Load mask as numpy array.

        :param mask: None (include every voxel), or an array whose first three
            dimensions match the spatial shape of the data. Any further
            dimensions must be singletons. Non-zero entries are included.
        :raises ValueError: if the mask cannot be matched to the data.
        """
        spatial_shape = self.data.shape[0:3]
        if mask is None:
            mask = np.full(spatial_shape, True)
        else:
            mask = np.asarray(mask)
            if mask.shape[0:3] != spatial_shape:
                raise ValueError(f'Mask must be None or numpy array of the same shape as FID.'
                                 f' Mask {mask.shape[0:3]}, FID {self.data.shape[0:3]}.')
            if mask.size != np.prod(spatial_shape):
                raise ValueError(f'Mask has non-singleton dimensions beyond the first three.'
                                 f' Mask {mask.shape}, FID {self.data.shape[0:3]}.')
            # Drop trailing singleton dims so self.mask[idx] is a scalar
            mask = mask.reshape(spatial_shape) != 0.0

        self.mask = mask
        self.num_masked_voxels = np.sum(self.mask)

    def set_tissue_seg(self, csf, wm, gm):
        """ Load tissue segmentation as numpy arrays.

        :raises ValueError: if any array's shape differs from spatial_shape.
        """
        if (csf.shape != self.spatial_shape) or (wm.shape != self.spatial_shape) or (gm.shape != self.spatial_shape):
            raise ValueError(f'Tissue segmentation arrays have wrong shape '
                             f'(CSF:{csf.shape}, GM:{gm.shape}, WM:{wm.shape}).'
                             f' Must match FID ({self.spatial_shape}).')

        self.csf = csf
        self.wm = wm
        self.gm = gm
        self.tissue_seg_loaded = True

    def list_to_matched_array(self, data_list, indicies=None, cleanup=True, dtype=float):
        '''Convert a list of per-voxel values into an array matching the MRSI grid.

        :param data_list: sequence of scalars or 1D arrays, one per voxel,
            in the order given by ``indicies``.
        :param indicies: voxel index for each entry; defaults to
            get_indicies_in_order() (masked voxels in iteration order).
        :param cleanup: if True, NaN/inf entries and entries whose magnitude
            is below 1e-10 or above 1e10 are set to zero.
        :param dtype: dtype of the returned array.
        :return: array of shape spatial_shape, or spatial_shape + (n,) when
            each entry holds n > 1 values. Voxels without an entry are zero.
        '''
        if indicies is None:
            indicies = self.get_indicies_in_order()

        if len(data_list) == 0:
            return np.zeros(self.spatial_shape, dtype=dtype)

        # np.size also works for plain Python scalars
        nt = np.size(data_list[0])
        if nt > 1:
            shape = self.spatial_shape + (nt,)
        else:
            shape = self.spatial_shape
        data = np.zeros(shape, dtype=dtype)

        for d, ind in zip(data_list, indicies):
            data[ind] = d

        if cleanup:
            data[~np.isfinite(data)] = 0
            # Compare magnitudes so negative values survive and complex dtypes work
            magnitude = np.abs(data)
            data[(magnitude < 1e-10) | (magnitude > 1e10)] = 0

        return data