core.py 16.5 KB
Newer Older
Saad Jbabdi's avatar
Saad Jbabdi committed
1
2
3
4
5
6
7
8
9
10
#!/usr/bin/env python

# core.py - main MRS class definition
#
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
#
# Copyright (C) 2019 University of Oxford 
# SHBASECOPYRIGHT


Saad Jbabdi's avatar
Saad Jbabdi committed
11
12
import warnings
from os.path import isfile
Saad Jbabdi's avatar
Saad Jbabdi committed
13
14

from fsl_mrs.utils import mrs_io as io
15
from fsl_mrs.utils import misc
16
from fsl_mrs.utils.constants import *
Saad Jbabdi's avatar
Saad Jbabdi committed
17
18

import numpy as np
19

Saad Jbabdi's avatar
Saad Jbabdi committed
20
21
22
23
24
25
26
27
28
29
30

#------------------------------------------------
#
#
#------------------------------------------------




class MRS(object):
    """
      MRS Class - container for FID, Basis, and sequence info

      Holds one FID plus, optionally, a water reference (H2O) and a basis
      set, together with the acquisition parameters and the time, frequency
      and ppm axes derived from them.
    """

    def __init__(self, FID=None, header=None, basis=None, names=None,
                 basis_hdr=None, H2O=None, cf=None, bw=None):
        """
          Build an MRS object from in-memory arrays.

          All array/list inputs are copied rather than referenced, so the same
          inputs can be reused to build several differently-configured
          objects.

          Parameters
          ----------
          FID : array or None
              Time-domain data. If None, an empty object is created in which
              only the optional attributes (H2O, basis, names, numBasis,
              basis_dwellTime, basis_bandwidth, metab_groups) exist, all set
              to None. from_files() uses this to re-initialise later.
          header : dict or None
              Must provide 'centralFrequency' and 'bandwidth'. Takes
              precedence over cf/bw.
          basis : array or None
              Basis FIDs, either 1D (a single spectrum) or 2D. Orientated so
              that columns are basis spectra (assumes more time points than
              basis spectra), then resampled onto the FID's time grid.
          names : list of str or None
              Basis spectrum names. Required when basis is given.
          basis_hdr : dict or None
              Basis header. Its 'bandwidth' entry sets the basis dwell time.
              Required when basis is given.
          H2O : array or None
              Optional water-reference FID.
          cf, bw : float or None
              Central frequency and bandwidth, used when header is None
              (see set_acquisition_params for units).

          Raises
          ------
          ValueError
              If FID is given without a header or without both cf and bw,
              or if basis is given without both names and basis_hdr.
        """
        # Optional attributes always exist, so methods that test e.g.
        # `self.basis is None` behave sensibly on an empty MRS().
        self.H2O = None
        self.basis = None
        self.names = None
        self.numBasis = None
        self.basis_dwellTime = None
        self.basis_bandwidth = None
        self.metab_groups = None

        if FID is None:
            return

        self.set_FID(FID)
        if H2O is not None:
            self.H2O = H2O.copy()

        # Set FID class attributes
        if header is not None:
            self.set_acquisition_params(centralFrequency=header['centralFrequency'],
                                        bandwidth=header['bandwidth'])
        elif (cf is not None) and (bw is not None):
            self.set_acquisition_params(centralFrequency=cf, bandwidth=bw)
        else:
            raise ValueError('You must pass a header or bandwidth and central frequency.')

        # Set Basis info
        if basis is not None:
            self.basis = basis.copy()
            # Handle single basis spectra
            if self.basis.ndim == 1:
                self.basis = self.basis[:, np.newaxis]
            # Assume that there will always be more timepoints than basis spectra.
            if self.basis.shape[0] < self.basis.shape[1]:
                self.basis = self.basis.T
            self.numBasis = self.basis.shape[1]
            self.numBasisPoints = self.basis.shape[0]

            if (names is None) or (basis_hdr is None):
                raise ValueError('Pass basis names and header with basis.')
            self.names = names.copy()
            self.set_acquisition_params_basis(1/basis_hdr['bandwidth'])

            # Now interpolate the basis to the same time axis.
            self.resample_basis()

    def from_files(self, FID_file, Basis_file):
        """
          (Re)initialise this object from an FID file and a basis file.

          Reads the data with io.read_FID / io.read_basis, then calls
          __init__ with the FID header's central frequency and bandwidth.
          Only the first basis header (Bheader[0]) is used.

          Parameters
          ----------
          FID_file : str
          Basis_file : str
        """
        FID, FIDheader = io.read_FID(FID_file)
        basis, names, Bheader = io.read_basis(Basis_file)

        self.__init__(FID=FID,
                      cf=FIDheader['centralFrequency'],
                      bw=FIDheader['bandwidth'],
                      basis=basis,
                      basis_hdr=Bheader[0],
                      names=names)

    def __str__(self):
        """Multi-line human-readable summary of the data, axes and basis."""
        out  = '------- MRS Object ---------\n'
        out += '     FID.shape             = {}\n'.format(self.FID.shape)
        out += '     FID.centralFreq (MHz) = {}\n'.format(self.centralFrequency/1e6)
        out += '     FID.centralFreq (T)   = {}\n'.format(self.centralFrequency/H1_gamma/1e6)
        out += '     FID.bandwidth (Hz)    = {}\n'.format(self.bandwidth)
        out += '     FID.dwelltime (s)     = {}\n'.format(self.dwellTime)
        if self.basis is not None:
            out += '     basis.shape           = {}\n'.format(self.basis.shape)
            out += '     Metabolites           = {}\n'.format(self.names)
            out += '     numBasis              = {}\n'.format(self.numBasis)
        out += '     timeAxis              = {}\n'.format(self.timeAxis.shape)
        out += '     freqAxis              = {}\n'.format(self.frequencyAxis.shape)
        return out

    # Acquisition parameters
    def set_acquisition_params(self, centralFrequency, bandwidth):
        """
          Set useful params for fitting.

          Stores the central frequency (normalised by misc.checkCFUnits and
          stored in Hz), the bandwidth, dwellTime = 1/bandwidth, and the axes
          returned by misc.calculateAxes for self.numPoints points. Because
          it reads self.numPoints, set_FID must have been called first.

          timeAxis, frequencyAxis and ppmAxisShift are reshaped to column
          vectors (N, 1). ppmAxis is stored as returned. ppmAxisFlip is
          ppmAxisShift flipped before that reshape.

          Parameters
          ----------
          centralFrequency : float  (unit=Hz)
          bandwidth : float (unit=Hz)
        """
        # Store CF in Hz
        self.centralFrequency = misc.checkCFUnits(centralFrequency)
        self.bandwidth        = bandwidth
        self.dwellTime        = 1/self.bandwidth

        axes = misc.calculateAxes(self.bandwidth,
                                  self.centralFrequency,
                                  self.numPoints)

        self.timeAxis         = axes['time']
        self.frequencyAxis    = axes['freq']
        self.ppmAxis          = axes['ppm']
        self.ppmAxisShift     = axes['ppmshift']
        self.ppmAxisFlip      = np.flipud(self.ppmAxisShift)
        # turn into column vectors
        self.timeAxis         = self.timeAxis[:, None]
        self.frequencyAxis    = self.frequencyAxis[:, None]
        self.ppmAxisShift     = self.ppmAxisShift[:, None]

    def set_acquisition_params_basis(self, dwelltime):
        """
           Set basis-specific timing params.

           The basis can have a different dwell time from the FID. Sets
           basis_dwellTime, basis_bandwidth (= 1/dwelltime), a symmetric
           frequency axis over [-bw/2, bw/2] and a time axis from dwelltime
           to dwelltime*numBasisPoints. Both axes have numBasisPoints
           samples, so numBasisPoints must already be set.

           Parameters
           ----------
           dwelltime : float (unit=sec)
        """
        self.basis_dwellTime     = dwelltime
        self.basis_bandwidth     = 1/dwelltime
        self.basis_frequencyAxis = np.linspace(-self.basis_bandwidth/2,
                                               self.basis_bandwidth/2,
                                               self.numBasisPoints)
        self.basis_timeAxis      = np.linspace(self.basis_dwellTime,
                                               self.basis_dwellTime*self.numBasisPoints,
                                               self.numBasisPoints)

    def getSpectrum(self, ppmlim=None, shift=True):
        """
           Return the spectrum of the FID (misc.FIDToSpec), optionally
           restricted to a ppm range.

           Parameters
           ----------
           ppmlim : tuple or None
               (low, high) ppm limits. None returns the full spectrum.
           shift : bool
               Locate ppmlim on ppmAxisShift if True, otherwise on ppmAxis.
        """
        spectrum = misc.FIDToSpec(self.FID)
        first, last = self.ppmlim_to_range(ppmlim, shift=shift)
        return spectrum[first:last]

    def getAxes(self, axis='ppmshift', ppmlim=None):
        """
           Return one of the stored axes, squeezed to 1D.

           Parameters
           ----------
           axis : str
               'ppmshift', 'ppm', 'freq' or 'time' (case-insensitive).
           ppmlim : tuple or None
               Optional ppm range. Ignored for 'time'. For 'freq' it is
               located on the unshifted ppm axis.

           Raises
           ------
           ValueError
               For an unknown axis name.
        """
        kind = axis.lower()
        if kind == 'time':
            return np.squeeze(self.timeAxis)
        if kind == 'ppmshift':
            first, last = self.ppmlim_to_range(ppmlim, shift=True)
            return np.squeeze(self.ppmAxisShift[first:last])
        if kind == 'ppm':
            first, last = self.ppmlim_to_range(ppmlim, shift=False)
            return np.squeeze(self.ppmAxis[first:last])
        if kind == 'freq':
            first, last = self.ppmlim_to_range(ppmlim, shift=False)
            return np.squeeze(self.frequencyAxis[first:last])
        raise ValueError('axis must be one of ppmshift, ppm, freq or time.')

    def ppmlim_to_range(self, ppmlim=None, shift=True):
        """
           Turn ppmlim into a data index range.

           Each limit is mapped to the index of the nearest sample on
           ppmAxisShift (shift=True) or ppmAxis (shift=False). The two indices
           are returned in ascending order. They are used as x[first:last],
           so the sample nearest the upper index is excluded.

           Parameters
           ----------
           ppmlim : tuple or None
               (a, b) in ppm, in either order. None selects all points.
           shift : bool

           Returns
           -------
           int : first position
           int : last position
        """
        if ppmlim is None:
            return 0, int(self.numPoints)

        axis = self.ppmAxisShift if shift else self.ppmAxis
        first = np.argmin(np.abs(axis - ppmlim[0]))
        last = np.argmin(np.abs(axis - ppmlim[1]))
        if first > last:
            first, last = last, first
        return int(first), int(last)

    def resample_basis(self):
        """
           Usually the basis is simulated using different timings/number of
           points. This interpolates the basis to match the FID.

           Maps the basis from basis_dwellTime onto the FID's dwellTime and
           numPoints using misc.ts_to_ts. It then updates numBasisPoints and
           every basis timing attribute (dwell time, bandwidth, time and
           frequency axes) to describe the new grid.
        """
        self.basis = misc.ts_to_ts(self.basis, self.basis_dwellTime,
                                   self.dwellTime, self.numPoints)
        self.numBasisPoints = self.numPoints
        # Recompute through the same routine __init__ uses so that the basis
        # axes describe the resampled data, not the simulation grid.
        self.set_acquisition_params_basis(self.dwellTime)

    # Helper functions
    def processForFitting(self, ppmlim=(.2, 4.2),):
        """
           Apply rescaling and run the conjugation checks.

           Conjugates the FID and (if present) the basis where check_FID /
           check_Basis decide it is needed, then calls rescaleForFitting.
        """
        self.check_FID(ppmlim=ppmlim, repair=True)
        self.check_Basis(ppmlim=ppmlim, repair=True)
        self.rescaleForFitting()

    def rescaleForFitting(self, scale=100):
        """
           Apply rescaling across data, basis and H2O.

           The FID is rescaled with misc.rescale_FID and the same factor is
           applied to H2O, keeping data and water reference comparable. The
           basis is rescaled by its own, separate misc.rescale_FID call.
        """
        scaledFID, scaling = misc.rescale_FID(self.FID, scale=scale)
        self.set_FID(scaledFID)
        if self.H2O is not None:
            self.H2O *= scaling

        if self.basis is not None:
            self.basis, _ = misc.rescale_FID(self.basis, scale=scale)

    @staticmethod
    def _conj_improves(fid, first, last):
        """
           True if the conjugate of `fid` scores higher than `fid` itself.

           Score: norm of the real part of misc.FIDToSpec(...) restricted to
           [first:last] and detrended with misc.detrend(deg=4).
        """
        spec = np.real(misc.FIDToSpec(fid))[first:last]
        spec_conj = np.real(misc.FIDToSpec(np.conj(fid)))[first:last]
        return (np.linalg.norm(misc.detrend(spec, deg=4))
                < np.linalg.norm(misc.detrend(spec_conj, deg=4)))

    def check_FID(self, ppmlim=(.2, 4.2), repair=False):
        """
           Check if FID needs to be conjugated by comparing the detrended
           real spectrum of the FID and of its conjugate within ppmlim.

        Parameters
        ----------
        ppmlim : list
        repair : if True applies conjugation to FID

        Returns
        -------
        0 if no conjugation is needed, -1 if it is needed and repair is
        False (also issues a warning), 1 if the FID was conjugated.
        """
        first, last = self.ppmlim_to_range(ppmlim)
        if not self._conj_improves(self.FID, first, last):
            return 0
        if repair is False:
            warnings.warn('YOU MAY NEED TO CONJUGATE YOUR FID!!!')
            return -1
        self.conj_FID()
        return 1

    def conj_FID(self):
        """
        Conjugate FID and recalculate spectrum
        """
        self.FID  = np.conj(self.FID)
        self.Spec = misc.FIDToSpec(self.FID)

    def check_Basis(self, ppmlim=(.2, 4.2), repair=False):
        """
           Check if Basis needs to be conjugated.

           Each basis spectrum votes using the same test as check_FID. The
           basis is flagged when more than half of the spectra vote for
           conjugation.

        Parameters
        ----------
        ppmlim : list
        repair : if True applies conjugation to basis

        Returns
        -------
        0 if no conjugation is needed (including when there is no basis or
        it is empty), -1 if it is needed and repair is False (also issues a
        warning), 1 if the basis was conjugated.
        """
        if self.basis is None:
            return 0

        first, last = self.ppmlim_to_range(ppmlim)
        votes = [self._conj_improves(b, first, last) for b in self.basis.T]
        if not votes or sum(votes) / len(votes) <= 0.5:
            return 0
        if repair is False:
            warnings.warn('YOU MAY NEED TO CONJUGATE YOUR BASIS!!!')
            return -1
        self.conj_Basis()
        return 1

    def conj_Basis(self):
        """
        Conjugate every basis spectrum (self.basis) in place of the stored array.
        """
        self.basis  = np.conj(self.basis)

    def ignore(self, metabs):
        """
          Ignore a subset of metabolites by removing them from the basis.

          All names are validated before anything is removed, so an invalid
          request leaves the basis untouched. Repeated names are removed once.

          Parameters
          ----------
          metabs : list of str or None
              Names to remove. None does nothing.

          Raises
          ------
          ValueError
              If there is no basis, or a name is not in self.names.
        """
        if self.basis is None:
            raise ValueError('You must first specify a basis before ignoring a subset of it!')
        if metabs is None:
            return

        requested = list(dict.fromkeys(metabs))  # drop repeats, keep order
        missing = [m for m in requested if m not in self.names]
        if missing:
            raise ValueError('Metabolite(s) {} not found in basis names {}.'.format(missing, self.names))

        for m in requested:
            idx = self.names.index(m)
            self.names.pop(idx)
            self.basis = np.delete(self.basis, idx, axis=1)
        self.numBasis = len(self.names)

    def keep(self, metabs):
        """
          Keep a subset of metabolites by removing all others from basis.

          Names in metabs that are not in the basis are ignored.

          Parameters
          ----------
          metabs : list of str or None
              Names to keep. None does nothing.

          Raises
          ------
          ValueError
              If metabs is given but there is no basis.
        """
        if metabs is None:
            return
        if self.basis is None:
            raise ValueError('You must first specify a basis before keeping a subset of it!')
        self.ignore([m for m in self.names if m not in metabs])

    def add_peak(self, ppm, name, gamma=0, sigma=0):
        """
           Add a peak created by misc.create_peak as a new basis column.

           Parameters
           ----------
           ppm : float
               Peak position.
           name : str
               Name appended to self.names.
           gamma, sigma : float
               Voigt blurring parameters passed to misc.create_peak.

           Raises
           ------
           ValueError
               If there is no basis to add to.
        """
        if self.basis is None:
            raise ValueError('You must first specify a basis before adding peaks to it!')
        peak = misc.create_peak(self, ppm, gamma, sigma)[:, None]
        self.basis = np.append(self.basis, peak, axis=1)
        self.names.append(name)
        self.numBasis += 1

    def add_MM_peaks(self, ppmlist=None, gamma=0, sigma=0):
        """
           Add macromolecule peaks to the basis.

           Each peak is named 'MM' + ppm*10 rounded and zero-padded to two
           digits (e.g. 1.7 -> 'MM17', 0.9 -> 'MM09').

        Parameters
        ----------
        ppmlist : default is [1.7,1.4,1.2,2.0,0.9]
        gamma,sigma : float parameters of Voigt blurring

        Returns
        -------
        int : number of peaks added
        """
        if ppmlist is None:
            ppmlist = [1.7, 1.4, 1.2, 2.0, 0.9]
        names = ['MM'+'{:.0f}'.format(i*10).zfill(2) for i in ppmlist]

        for name, ppm in zip(names, ppmlist):
            self.add_peak(ppm, name, gamma, sigma)

        return len(ppmlist)

    def set_FID(self, FID):
        """
          Store a copy of FID, its size (numPoints) and its spectrum
          (misc.FIDToSpec). Axes are not recomputed here.
        """
        self.FID         = FID.copy()
        self.numPoints   = self.FID.size
        self.Spec        = misc.FIDToSpec(self.FID)