plotting.py 47.9 KB
Newer Older
Saad Jbabdi's avatar
Saad Jbabdi committed
1
2
3
4
5
#!/usr/bin/env python

# plotting.py - MRS plotting helper functions
#
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
#         Will Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT

11
12
13
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
14
from fsl_mrs.utils import mrs_io
Saad Jbabdi's avatar
Saad Jbabdi committed
15
16
17
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
18
from plotly import tools
Saad Jbabdi's avatar
Saad Jbabdi committed
19
20
21
import nibabel as nib
import scipy.ndimage as ndimage
import itertools as it
Saad Jbabdi's avatar
Saad Jbabdi committed
22

23
from fsl_mrs.utils.misc import FIDToSpec, SpecToFID, limit_to_range
24

Saad Jbabdi's avatar
Saad Jbabdi committed
25
26
27
28
29

def FID2Spec(x):
    """Convert a time-domain FID into a spectrum for plotting.

    Thin wrapper kept for backwards compatibility around
    ``fsl_mrs.utils.misc.FIDToSpec``.

    :param x: FID array
    :return: whatever ``FIDToSpec(x)`` returns (the spectrum)
    """
    spectrum = FIDToSpec(x)
    return spectrum


34
35
def plot_fit(mrs, pred=None, ppmlim=(0.40, 4.2),
             out=None, baseline=None, proj='real'):
    """
       Main function for plotting a model fit
       Parameters
       ----------
       mrs      : MRS Object
       pred     : array-like
                  predicted FID. If not provided, tries to get it from mrs object
       ppmlim   : tuple
                  (MIN,MAX)
       out      : string
                  output figure filename
       baseline : array-like, optional
                  baseline FID to overlay
       proj     : string
                  one of 'real', 'imag', 'abs', or 'angle'

       Returns
       -------
       matplotlib Figure (residual panel on top, data/fit panel below)

       Raises
       ------
       ValueError if proj is not one of the supported projections
    """
    projections = {'real': np.real, 'imag': np.imag,
                   'abs': np.abs, 'angle': np.angle}
    if proj not in projections:
        raise ValueError("proj should be one of 'real', 'imag', 'abs', or 'angle'.")
    project = projections[proj]

    def axes_style(plt, ppmlim, label=None, xticks=None):
        plt.xlim(ppmlim)
        plt.gca().invert_xaxis()
        plt.xlabel(label)
        plt.gca().set_xticks(xticks)
        plt.minorticks_on()
        # Visibility passed positionally: the ``b=`` keyword was renamed to
        # ``visible`` in matplotlib 3.5 and no longer works in newer releases.
        plt.grid(True, axis='x', which='major', color='k', linestyle='--', linewidth=.3)
        plt.grid(True, axis='x', which='minor', color='k', linestyle=':', linewidth=.3)

    def doPlot(data, c='b', linewidth=1, linestyle='-', xticks=None):
        plt.plot(mrs.getAxes(), data, color=c, linewidth=linewidth, linestyle=linestyle)
        axes_style(plt, ppmlim, label='Chemical shift (ppm)', xticks=xticks)

    # Prepare data for plotting: convert to spectra and project once.
    if pred is None:
        pred = mrs.pred
    data = project(FID2Spec(mrs.FID))
    pred = project(FID2Spec(pred))
    if baseline is not None:
        baseline = project(FID2Spec(baseline))

    first, last = mrs.ppmlim_to_range(ppmlim=ppmlim, shift=True)
    if first > last:
        first, last = last, first

    # y-limits from the data and fit inside the plotted ppm range
    m = min(data[first:last].min(), pred[first:last].min())
    M = max(data[first:last].max(), pred[first:last].max())
    ylim = (m - np.abs(M) / 10, M + np.abs(M) / 10)

    # Create the figure
    plt.figure(figsize=(9, 10))

    # Subplots: thin residual panel above the main panel
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 20])

    plt.subplot(gs[0])
    # Start by plotting error. data and pred are already projected, so the
    # difference is plotted directly (projecting again would, e.g., turn the
    # 'imag' residual into zeros).
    xticks = np.arange(ppmlim[0], ppmlim[1] + .2, .2)
    plt.plot(mrs.getAxes(), data - pred, c='k', linewidth=1, linestyle='-')
    axes_style(plt, ppmlim, xticks=xticks)
    plt.gca().set_xticklabels([])

    plt.subplot(gs[1])
    doPlot(data, c='k', linewidth=.5, xticks=xticks)
    doPlot(pred, c='#cc0000', linewidth=1, xticks=xticks)
    if baseline is not None:
        doPlot(baseline, c='k', linewidth=.5, xticks=xticks)

    # plot y=0
    doPlot(data * 0, c='k', linestyle=':', linewidth=1, xticks=xticks)

    plt.legend(['data', 'model fit'])

    plt.tight_layout()
    plt.ylim(ylim)

    if out is not None:
        plt.savefig(out)

    return plt.gcf()
133

134
135

def plot_fit_new(mrs, ppmlim=(0.40, 4.2)):
    """
        Plot model fitting plus baseline

        mrs    : MRS object
        ppmlim : tuple
                 (MIN, MAX) chemical shift range to display

        Returns the matplotlib Figure.
    """
    axis = mrs.getAxes()
    spec = np.flipud(np.fft.fftshift(mrs.get_spec()))
    pred = FIDToSpec(mrs.pred)
    pred = np.flipud(np.fft.fftshift(pred))

    if mrs.baseline is not None:
        # NOTE(review): unlike pred, the baseline is not passed through
        # FIDToSpec — presumably mrs.baseline is already a spectrum; confirm.
        B = np.flipud(np.fft.fftshift(mrs.baseline))

    # Indices of the axis samples closest to the requested ppm limits
    first = np.argmin(np.abs(axis - ppmlim[0]))
    last = np.argmin(np.abs(axis - ppmlim[1]))
    if first > last:
        first, last = last, first

    plt.figure(figsize=(9, 10))
    plt.plot(axis[first:last], spec[first:last])
    plt.gca().invert_xaxis()
    plt.plot(axis[first:last], pred[first:last], 'r')
    if mrs.baseline is not None:
        plt.plot(axis[first:last], B[first:last], 'k')

    # style stuff; grid visibility passed positionally because the ``b=``
    # keyword no longer works in newer matplotlib releases.
    plt.minorticks_on()
    plt.grid(True, axis='x', which='major', color='k', linestyle='--', linewidth=.3)
    plt.grid(True, axis='x', which='minor', color='k', linestyle=':', linewidth=.3)

    return plt.gcf()
168
169


170
def plot_waterfall(mrs, ppmlim=(0.4, 4.2), proj='real', mod=True):
    """
       Plot individual metabolite spectra, one stacked row per basis spectrum.

       Parameters
       ----------
       mrs    : MRS object
       ppmlim : tuple
                (MIN, MAX) chemical shift range of every row
       proj   : either 'real' or 'imag' or 'abs' or 'angle'
       mod    : True or False
                whether to multiply by estimated concentrations (mrs.con) or not

       Returns
       -------
       matplotlib Figure
    """
    rows = gridspec.GridSpec(mrs.numBasis, 1)
    fig = plt.figure(figsize=(5, 10))

    for row in range(mrs.numBasis):
        plt.subplot(rows[row])
        plt.xlim(ppmlim)
        ax = plt.gca()
        ax.invert_xaxis()
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_ylabel(mrs.names[row], rotation='horizontal')
        plt.box(False)

        fid = mrs.basis[:, row]
        if mod and mrs.con is not None:
            # scale the basis FID by its estimated concentration
            fid = mrs.con[row] * fid
        spectrum = FID2Spec(fid)
        plt.plot(mrs.getAxes(), data_proj(spectrum, proj),
                 c='r', linewidth=1, linestyle='-')

    return fig
Saad Jbabdi's avatar
Saad Jbabdi committed
201
202


203
def plot_spectrum(mrs, ppmlim=(0.0, 4.5), FID=None, proj='real', c='k'):
    """
       Plotting the spectrum

       Parameters
       ----------
       mrs    : MRS object
       ppmlim : tuple
              (MIN,MAX)
       FID    : array-like, optional
              FID to plot instead of the spectrum stored in mrs
       proj   : string
              one of 'real', 'imag', 'abs', or 'angle'
       c      : line colour

       Returns
       -------
       matplotlib Figure
    """
    ppmAxisShift = mrs.getAxes(ppmlim=ppmlim)

    def axes_style(plt, ppmlim, label=None, xticks=None):
        plt.xlim(ppmlim)
        plt.gca().invert_xaxis()
        plt.xlabel(label)
        plt.gca().set_xticks(xticks)
        plt.minorticks_on()
        # Visibility passed positionally: the ``b=`` keyword no longer works
        # in newer matplotlib releases.
        plt.grid(True, axis='x', which='major', color='k', linestyle='--', linewidth=.3)
        plt.grid(True, axis='x', which='minor', color='k', linestyle=':', linewidth=.3)

    def doPlot(data, c='b', linewidth=1, linestyle='-', xticks=None):
        plt.plot(ppmAxisShift, data, color=c, linewidth=linewidth, linestyle=linestyle)
        axes_style(plt, ppmlim, label='Chemical shift (ppm)', xticks=xticks)

    # Prepare data for plotting
    if FID is not None:
        first, last = mrs.ppmlim_to_range(ppmlim)
        data = FIDToSpec(FID)[first:last]
    else:
        data = mrs.get_spec(ppmlim=ppmlim)

    # Some nicer x ticks on the plots: whole-ppm ticks for ranges wider than
    # 2 ppm, 0.1 ppm ticks otherwise.
    if np.abs(ppmlim[1] - ppmlim[0]) > 2:
        xticks = np.arange(np.ceil(ppmlim[0]), np.floor(ppmlim[1]) + 0.1, 1.0)
    else:
        xticks = np.arange(np.around(ppmlim[0], 1), np.around(ppmlim[1], 1) + 0.01, 0.1)

    doPlot(data_proj(data, proj), c=c, linewidth=2, xticks=xticks)

    plt.tight_layout()
    return plt.gcf()
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270


def data_proj(x, proj):
    """Project (typically complex) data onto a real-valued quantity.

    :param x: array-like data
    :param proj: one of 'real', 'imag', 'abs', or 'angle'
    :return: np.real(x), np.imag(x), np.abs(x) or np.angle(x) respectively
    :raises ValueError: if proj is not one of the four supported values
    """
    if proj == 'real':
        return np.real(x)
    if proj == 'imag':
        return np.imag(x)
    if proj == 'abs':
        return np.abs(x)
    if proj == 'angle':
        return np.angle(x)
    raise ValueError("proj should be one of 'real', 'imag', 'abs', or 'angle'.")

271

272
273
274
275
276
def plot_fid(mrs, tlim=None, FID=None, proj='real', c='k'):
    """Plot a time-domain FID.

    :param mrs: MRS object providing the time axis (and the FID if FID is None)
    :param tlim: optional (min, max) x-axis limits, in the units of the time axis
    :param FID: optional FID to plot instead of mrs.FID
    :param proj: name of a numpy function applied to the data,
                 e.g. 'real', 'imag', 'abs' or 'angle'
    :param c: line colour
    :return: matplotlib Figure
    """
    time_axis = mrs.getAxes(axis='time')

    data = mrs.FID if FID is None else FID
    data = getattr(np, proj)(data)

    plt.plot(time_axis, data, color=c, linewidth=2)

    if tlim is not None:
        plt.xlim(tlim)
    plt.xlabel('Time (s)')
    plt.minorticks_on()
    # Visibility passed positionally: the ``b=`` keyword no longer works in
    # newer matplotlib releases.
    plt.grid(True, axis='x', which='major', color='k', linestyle='--', linewidth=.3)
    plt.grid(True, axis='x', which='minor', color='k', linestyle=':', linewidth=.3)

    plt.tight_layout()
    return plt.gcf()

296

297
298
299
300
301
302
303
304
305
306
307
def plot_mrs_basis(mrs, plot_spec=False, ppmlim=(0.0, 4.5)):
    """Plot the formatted basis and optionally the spectrum held by an MRS object.

    Only the real part of each spectrum is drawn, restricted to ``ppmlim``.

    :param mrs: MRS object
    :type mrs: fsl_mrs.core.mrs.MRS
    :param plot_spec: If True also draw the data spectrum (black) on the same axes, defaults to False
    :type plot_spec: bool, optional
    :param ppmlim: Chemical shift plotting range, defaults to (0.0, 4.5)
    :type ppmlim: tuple, optional
    :return: Figure object
    """
    first, last = mrs.ppmlim_to_range(ppmlim=ppmlim)

    for col, metab_name in enumerate(mrs.names):
        ppm = mrs.getAxes(ppmlim=ppmlim)
        real_spec = np.real(FID2Spec(mrs.basis[:, col]))
        plt.plot(ppm, real_spec[first:last], label=metab_name)

    if plot_spec:
        ppm = mrs.getAxes(ppmlim=ppmlim)
        plt.plot(ppm, np.real(mrs.get_spec(ppmlim=ppmlim)), 'k', label='Data')

    plt.gca().invert_xaxis()
    plt.xlabel('Chemical shift (ppm)')
    plt.legend()

    return plt.gcf()


def plot_basis(basis, ppmlim=(0.0, 4.5), shift=True, conjugate=False):
    """Plot the basis spectra contained in a Basis object.

    The colour map is chosen by the number of metabolites: tab10 for up to
    10, tab20 for up to 20, nipy_spectral beyond that.

    :param basis: Basis object
    :type basis: fsl_mrs.core.basis.Basis
    :param ppmlim: Chemical shift plotting limits on x axis, defaults to (0.0, 4.5)
    :type ppmlim: tuple, optional
    :param shift: Apply chemical shift referencing shift, defaults to True.
    :type shift: Bool, optional
    :param conjugate: Apply conjugation (flips frequency direction), defaults to False.
    :type conjugate: Bool, optional
    :return: Figure object
    """
    axis = basis.original_ppm_shift_axis if shift else basis.original_ppm_axis
    first, last = limit_to_range(axis, ppmlim)

    n_met = basis.n_metabs
    if n_met > 20:
        cmap = plt.cm.nipy_spectral
    elif n_met > 10:
        cmap = plt.cm.tab20
    else:
        cmap = plt.cm.tab10
    colors = cmap(np.linspace(0, 1, n_met))

    ax = plt.gca()
    ax.set_prop_cycle('color', colors)

    for col, metab_name in enumerate(basis.names):
        fid = basis.original_basis_array[:, col]
        if conjugate:
            fid = fid.conj()
        ax.plot(axis[first:last],
                np.real(FID2Spec(fid))[first:last],
                label=metab_name)

    ax.invert_xaxis()
    plt.xlabel('Chemical shift (ppm)')
    plt.legend()

    return plt.gcf()

371

372
373
374
def plot_spectra(MRSList, ppmlim=(0, 4.5), single_FID=None, plot_avg=True):
    """Overlay the (real) spectra of several MRS objects.

    Each spectrum in MRSList is drawn as a thin black line. Optionally the
    spectrum of single_FID is drawn in red and the mean of MRSList in green.

    :param MRSList: iterable of MRS objects
    :param ppmlim: chemical shift range to display
    :param single_FID: optional MRS object to highlight
    :param plot_avg: if True, plot the average of the MRSList spectra
    :return: matplotlib Figure
    """
    plt.figure(figsize=(10, 10))
    plt.xlim(ppmlim)
    plt.gca().invert_xaxis()
    plt.minorticks_on()
    # Visibility passed positionally: the ``b=`` keyword no longer works in
    # newer matplotlib releases.
    plt.grid(True, axis='x', which='major', color='k', linestyle='--', linewidth=.3)
    plt.grid(True, axis='x', which='minor', color='k', linestyle=':', linewidth=.3)

    plt.autoscale(enable=True, axis='y', tight=True)

    avg = 0
    n_spectra = 0
    ppmAxisShift = None
    for mrs in MRSList:
        data = np.real(mrs.get_spec(ppmlim=ppmlim))
        ppmAxisShift = mrs.getAxes(ppmlim=ppmlim)
        avg += data
        n_spectra += 1
        plt.plot(ppmAxisShift, data, color='k', linewidth=.5, linestyle='-')

    if single_FID is not None:
        # Plot against single_FID's own axis rather than the last list item's
        # (which does not exist when MRSList is empty).
        data = np.real(single_FID.get_spec(ppmlim=ppmlim))
        plt.plot(single_FID.getAxes(ppmlim=ppmlim), data, color='r', linewidth=2, linestyle='-')

    if plot_avg and n_spectra > 0:
        avg /= n_spectra
        plt.plot(ppmAxisShift, avg, color='g', linewidth=2, linestyle='-')

    autoscale_y(plt.gca(), margin=0.05)

    return plt.gcf()
399

400
401

def autoscale_y(ax, margin=0.1):
    """Rescale the y-axis to the data visible within the current x-limits.

    Works for both normal and inverted x-axes. Lines with no samples inside
    the visible x-range are ignored; if no line has visible samples the
    y-limits are left unchanged.

    ax -- a matplotlib axes object
    margin -- the fraction of the total height of the y-data to pad the upper and lower ylims
    """
    # sorted() makes this independent of whether the x-axis is inverted
    x_lo, x_hi = sorted(ax.get_xlim())

    bot, top = np.inf, -np.inf
    for line in ax.get_lines():
        xd = np.asarray(line.get_xdata())
        yd = np.asarray(line.get_ydata())
        y_displayed = yd[(xd > x_lo) & (xd < x_hi)]
        if y_displayed.size == 0:
            continue  # nothing of this line is visible
        y_min, y_max = np.min(y_displayed), np.max(y_displayed)
        h = y_max - y_min
        bot = min(bot, y_min - margin * h)
        top = max(top, y_max + margin * h)

    if np.isfinite(bot) and np.isfinite(top):
        ax.set_ylim(bot, top)

428

429
def plot_fit_pretty(mrs, pred=None, ppmlim=(0.40, 4.2), proj='real'):
    """
       Interactive (plotly) plot of a model fit with its residual
       Parameters
       ----------
       mrs    : MRS Object
       pred   : array-like
              predicted FID. If not provided, tries to get it from mrs object
       ppmlim : tuple
              (MIN,MAX)
       proj   : string
              one of 'real', 'imag', 'abs', or 'angle'

       Returns
       -------
       plotly Figure: residual in a short top row, data and prediction below

       Raises
       ------
       ValueError if proj is not one of the supported projections
    """
    projections = {'real': np.real, 'imag': np.imag,
                   'abs': np.abs, 'angle': np.angle}
    if proj not in projections:
        raise ValueError("proj should be one of 'real', 'imag', 'abs', or 'angle'.")
    project = projections[proj]

    if pred is None:
        pred = mrs.pred

    data = project(FID2Spec(mrs.FID))
    pred = project(FID2Spec(pred))
    err = data - pred
    x = mrs.getAxes()

    # row_heights is listed top-to-bottom (equivalent to the old
    # tools.make_subplots row_width=[10, 1], which was bottom-to-top).
    fig = make_subplots(rows=2,
                        row_heights=[1, 10],
                        shared_xaxes=True,
                        print_grid=False,
                        vertical_spacing=0)

    trace_data = go.Scatter(x=x, y=data, name='data', hoverinfo="none")
    trace_pred = go.Scatter(x=x, y=pred, name='pred', hoverinfo="none")
    trace_err = go.Scatter(x=x, y=err, name='error', hoverinfo="none")

    fig.add_trace(trace_err, row=1, col=1)
    fig.add_trace(trace_data, row=2, col=1)
    fig.add_trace(trace_pred, row=2, col=1)

    fig['layout'].update(autosize=True,
                         title=None,
                         showlegend=True,
                         margin={'t': 0.01, 'r': 0, 'l': 20})

    # Reversed range so chemical shift decreases left to right
    fig['layout']['xaxis'].update(zeroline=False,
                                  title='Chemical shift (ppm)',
                                  automargin=True,
                                  range=[ppmlim[1], ppmlim[0]])
    fig['layout']['yaxis'].update(zeroline=False, automargin=True)
    fig['layout']['yaxis2'].update(zeroline=False, automargin=True)

    return fig

482

483
# plotly imports
484

485
def plotly_fit(mrs, res, ppmlim=(.2, 4.2), proj='real', metabs=None, phs=(0, 0)):
    """
         plot model fitting plus baseline

    Parameters:
         mrs    : MRS object
         res    : ResFit Object
         ppmlim : tuple (MIN, MAX); if None, res.ppmlim is used
         proj   : 'real', 'imag' or 'angle'; any other value is shown as abs
         metabs : list of metabolite to include in pred
         phs    : display phasing in degrees and seconds

    Returns
         fig
     """
    def project(x, proj):
        if proj == 'real':
            return np.real(x)
        elif proj == 'imag':
            return np.imag(x)
        elif proj == 'angle':
            return np.angle(x)
        else:
            return np.abs(x)

    # Prepare the data
    base = FID2Spec(res.baseline)
    axis = mrs.getAxes()
    data = FID2Spec(mrs.FID)

    if ppmlim is None:
        ppmlim = res.ppmlim

    if metabs is not None:
        # Model restricted to the requested metabolites, plus baseline
        preds = []
        for m in metabs:
            preds.append(FID2Spec(pred(mrs, res, m, add_baseline=False)))
        preds = sum(preds)
        preds += FID2Spec(res.baseline)
        resid = data - preds
    else:
        preds = FID2Spec(res.pred)
        resid = FID2Spec(res.residuals)

    # phasing: zero-order term phs[0] (degrees), first-order term phs[1]
    # multiplied with the frequency axis
    faxis = mrs.getAxes(axis='freq')
    phaseTerm = np.exp(1j * (phs[0] * np.pi / 180)) * np.exp(1j * 2 * np.pi * phs[1] * faxis)

    base *= phaseTerm
    data *= phaseTerm
    preds *= phaseTerm
    resid *= phaseTerm

    base = project(base, proj)
    data = project(data, proj)
    preds = project(preds, proj)
    resid = project(resid, proj)

    # y-axis range. Padding uses the magnitude of the extremes so that a
    # negative minimum (or maximum) widens the range instead of clipping it.
    minval = min(np.min(base), np.min(data), np.min(preds), np.min(resid))
    maxval = max(np.max(base), np.max(data), np.max(preds), np.max(resid))
    ymin = minval - np.abs(minval) / 2
    ymax = maxval + np.abs(maxval) / 30

    # Build the plot

    # Table
    df = pd.DataFrame()
    df['Metab'] = res.metabs
    if res.concScalings['molality'] is not None:
        df['mMol/kg'] = np.round(res.getConc(scaling='molality'), decimals=2)
        df['CRLB'] = np.round(res.getUncertainties(type='molality'), decimals=2)
    else:
        df['unscaled'] = np.round(res.getConc(), decimals=2)
        df['CRLB'] = np.round(res.getUncertainties(type='raw'), decimals=3)
    df['%CRLB'] = np.round(res.getUncertainties(), decimals=1)
    if res.concScalings['internal'] is not None:
        concstr = f'/{res.concScalings["internalRef"]}'
        df[concstr] = np.round(res.getConc(scaling='internal'), decimals=2)

    tab = create_table(df)

    colors = dict(data='rgb(67,67,67)',
                  pred='rgb(253,59,59)',
                  base='rgb(0,150,242)',
                  resid='rgb(170,170,170)')
    line_size = dict(data=1,
                     pred=2,
                     base=1, resid=1)

    trace1 = go.Scatter(x=axis, y=data,
                        mode='lines',
                        name='data',
                        line=dict(color=colors['data'], width=line_size['data']),
                        )
    trace2 = go.Scatter(x=axis, y=preds,
                        mode='lines',
                        name='model',
                        line=dict(color=colors['pred'], width=line_size['pred']),
                        )
    trace3 = go.Scatter(x=axis, y=base,
                        mode='lines',
                        name='baseline',
                        line=dict(color=colors['base'], width=line_size['base']),
                        )
    trace4 = go.Scatter(x=axis, y=resid,
                        mode='lines',
                        name='residuals',
                        line=dict(color=colors['resid'], width=line_size['resid']),
                        )

    fig = make_subplots(rows=1, cols=2,
                        column_widths=[0.4, 0.6],
                        horizontal_spacing=0.03,
                        specs=[[{'type': 'table'}, {'type': 'scatter'}]])

    fig.add_trace(tab, row=1, col=1)
    fig.add_trace(trace1, row=1, col=2)
    fig.add_trace(trace2, row=1, col=2)
    fig.add_trace(trace3, row=1, col=2)
    fig.add_trace(trace4, row=1, col=2)

    fig.update_layout(template='plotly_white')

    fig.update_xaxes({'domain': [0.4, 1.]}, row=1, col=2)
    # Reversed range so chemical shift decreases left to right
    fig.update_xaxes(title_text='Chemical shift (ppm)',
                     tick0=2, dtick=.5,
                     range=[ppmlim[1], ppmlim[0]])

    fig.update_yaxes(zeroline=True,
                     zerolinewidth=1,
                     zerolinecolor='Gray',
                     showgrid=False, showticklabels=False,
                     range=[ymin, ymax])

    fig.layout.update({'height': 800})

    return fig


625
def plot_dist_approx(res, refname='Cr'):
    """Plot Gaussian approximations of each metabolite's marginal distribution.

    For every metabolite in res.original_metabs, a normal density with mean
    res.params[i] and standard deviation sqrt(res.crlb[i]) (both divided by
    the reference concentration) is drawn over [0, 2 * mean] in its own
    subplot.

    :param res: fit results object
    :param refname: metabolite used for normalisation; None for no normalisation
    :return: plotly Figure
    """
    numOrigMetabs = len(res.original_metabs)
    n = int(np.ceil(np.sqrt(numOrigMetabs)))
    fig = make_subplots(rows=n, cols=n, subplot_titles=res.original_metabs)

    if refname is not None:
        ref = res.getConc()[res.metabs.index(refname)]
    else:
        ref = 1.0

    for i, metab in enumerate(res.original_metabs):
        (r, c) = divmod(i, n)
        mu = res.params[i] / ref
        sig = np.sqrt(res.crlb[i]) / ref
        x = np.linspace(mu - mu, mu + mu, 100)
        # Normal density exp(-(x - mu)^2 / (2 sig^2)); the factor 2 was
        # missing, which made the curves too narrow by sqrt(2).
        N = np.exp(-(x - mu)**2 / (2 * sig**2))
        # normalise numerically so the curve integrates to 1 on the grid
        N = N / N.sum() / (x[1] - x[0])
        t = go.Scatter(x=x, y=N, mode='lines',
                       name=metab, line=dict(color='black'))
        fig.add_trace(t, row=r + 1, col=c + 1)

    fig.update_layout(template='plotly_white',
                      showlegend=False,
                      font=dict(size=10),
                      title='Approximate marginal distributions (ref={})'.format(refname),
                      height=700, width=700)
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(size=10, color='#ff0000')

    fig.update_layout(autosize=True)

    return fig


658
def plot_corr(res, corr=None, title='Correlation'):
    """Heatmap of the metabolite correlation matrix with Real/Abs toggle buttons.

    The diagonal is masked with NaN. The input matrix is copied first, so
    neither ``corr`` nor ``res.mcmc_cor`` is modified.

    :param res: fit results object (provides original_metabs and mcmc_cor)
    :param corr: optional square correlation matrix; defaults to res.mcmc_cor
    :param title: figure title
    :return: plotly Figure
    """
    # Other available colorscales:
    # Greys,YlGnBu,Greens,YlOrRd,Bluered,RdBu,Reds,Blues,
    # Picnic,Rainbow,Portland,Jet,Hot,Blackbody,Earth,
    # Electric,Viridis,Cividis.
    fig = go.Figure()

    if corr is None:
        corr = res.mcmc_cor
    # Work on a float copy: fill_diagonal writes in place and would otherwise
    # overwrite the caller's matrix (e.g. res.mcmc_cor) and fail on int input.
    corr = np.array(corr, dtype=float)
    np.fill_diagonal(corr, np.nan)
    corrabs = np.abs(corr)

    fig.add_trace(go.Heatmap(z=corr,
                             x=res.original_metabs, y=res.original_metabs, colorscale='Picnic', zmid=0))

    fig.update_layout(template='plotly_white',
                      font=dict(size=10),
                      title=title,
                      width=600,
                      height=600,
                      yaxis=dict(
                          scaleanchor="x",
                          scaleratio=1,
                      ),
                      updatemenus=[
                          dict(
                              type="buttons",
                              direction="left",
                              buttons=list([
                                  dict(
                                      args=[{"z": [corr], "colorscale":'Picnic'}],
                                      label="Real",
                                      method="restyle"
                                  ),
                                  dict(
                                      args=[{"z": [corrabs], "colorscale":'Picnic'}],
                                      label="Abs",
                                      method="restyle"
                                  )
                              ]),
                              pad={"r": 10, "t": 10},
                              showactive=True,
                              x=0.11,
                              xanchor="left",
                              y=1.1,
                              yanchor="top"
                          ),
                      ])

    fig.update_layout(autosize=True)

    return fig

710
711

def plot_dist_mcmc(res, refname='Cr'):
    """Plot per-metabolite histograms of the MCMC samples.

    Samples in res.fitResults[metab] are divided by the mean of the
    reference metabolite's samples (or by 1.0 when refname is None) and
    shown as percentage-normalised histograms on a square grid of subplots.

    :param res: fit results object with numMetabs, metabs and fitResults
    :param refname: reference metabolite name, or None
    :return: plotly Figure
    """
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    grid_size = int(np.ceil(np.sqrt(res.numMetabs)))
    fig = make_subplots(rows=grid_size, cols=grid_size, subplot_titles=res.metabs)

    ref = 1.0 if refname is None else res.fitResults[refname].mean()

    for idx, metab in enumerate(res.metabs):
        row, col = divmod(idx, grid_size)
        samples = res.fitResults[metab].to_numpy() / ref
        hist = go.Histogram(x=samples,
                            name=metab,
                            histnorm='percent',
                            marker_color='#330C73',
                            opacity=0.75)
        fig.add_trace(hist, row=row + 1, col=col + 1)

    fig.update_layout(template='plotly_white',
                      showlegend=False,
                      width=700,
                      height=700,
                      font=dict(size=10),
                      title='MCMC marginal distributions (ref={})'.format(refname))

    # subplot titles are annotations; restyle them
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(size=10, color='#ff0000')

    fig.update_layout(autosize=True)

    return fig

744
745

def plot_real_imag(mrs, res, ppmlim=(.2, 4.2)):
    """
         Plot data and model side by side: real part (left), imaginary part (right).

    Parameters:
         mrs    : MRS object
         res    : ResFit Object
         ppmlim : tuple (MIN, MAX) chemical shift range of both panels

    Returns
         fig
     """
    axis = mrs.getAxes()
    data_spec = FID2Spec(mrs.FID)
    pred_spec = FID2Spec(res.pred)

    fig = make_subplots(rows=1, cols=2, subplot_titles=['Real', 'Imag'])

    data_line = dict(color='rgb(67,67,67)', width=1)
    pred_line = dict(color='rgb(253,59,59)', width=2)

    # (column, projection, data trace name, model trace name)
    panels = ((1, np.real, 'data : real', 'model : real'),
              (2, np.imag, 'data : imag', 'model : imag'))
    for col, part, data_name, model_name in panels:
        fig.add_trace(go.Scatter(x=axis, y=part(data_spec),
                                 mode='lines',
                                 name=data_name,
                                 line=dict(data_line)),
                      row=1, col=col)
        fig.add_trace(go.Scatter(x=axis, y=part(pred_spec),
                                 mode='lines',
                                 name=model_name,
                                 line=dict(pred_line)),
                      row=1, col=col)

    # Reversed range so chemical shift decreases left to right
    for xax in (fig.layout.xaxis, fig.layout.xaxis2):
        xax.update(title_text='Chemical shift (ppm)',
                   tick0=2, dtick=.5,
                   range=[ppmlim[1], ppmlim[0]])

    for yax in (fig.layout.yaxis2, fig.layout.yaxis):
        yax.update(zeroline=True,
                   zerolinewidth=1,
                   zerolinecolor='Gray',
                   showgrid=False, showticklabels=False)

    fig.update_layout(template='plotly_white')

    return fig


833
def pred(mrs, res, metab, add_baseline=True):
    """
       Predict the FID of a single metabolite from a fit result

       All concentrations except that of `metab` are set to zero before the
       forward model is evaluated, so the result contains only `metab`
       (plus the fitted baseline unless `add_baseline` is False).

       Parameters
       ----------
       mrs          : MRS Object
                    provides names, numBasis, frequencyAxis, timeAxis, basis
       res          : fit results object
                    reads res.model, res.params, res.g, res.base_poly and
                    res.metab_groups
       metab        : string
                    metabolite name; must be one of mrs.names
       add_baseline : bool
                    if False, the baseline polynomial is replaced by zeros

       Returns
       -------
       predicted FID (the forward model's spectrum converted via SpecToFID)

       Raises
       ------
       ValueError : if res.model is neither 'lorentzian' nor 'voigt', or if
                    metab is not in mrs.names
    """
    from fsl_mrs.utils import models

    # Pick the forward model and its matching (un)packing helpers.
    if res.model == 'lorentzian':
        forward = models.FSLModel_forward
        x2param = models.FSLModel_x2param
        param2x = models.FSLModel_param2x
    elif res.model == 'voigt':
        forward = models.FSLModel_forward_Voigt
        x2param = models.FSLModel_x2param_Voigt
        param2x = models.FSLModel_param2x_Voigt
    else:
        raise ValueError(f'Unknown model: {res.model!r}.')

    idx = mrs.names.index(metab)

    # The first unpacked parameter is the concentration vector; the rest
    # (lineshape, shift, phase and baseline parameters) are passed through
    # unchanged, so one code path serves both models.
    con, *other_params = x2param(res.params, mrs.numBasis, res.g)

    # Keep only the requested metabolite's concentration. zeros_like (rather
    # than 0 * con) keeps NaN/inf in other entries from leaking through.
    single_con = np.zeros_like(con)
    single_con[idx] = con[idx]
    x = param2x(single_con, *other_params)

    base_poly = res.base_poly if add_baseline else np.zeros(res.base_poly.shape)
    spec = forward(x, mrs.frequencyAxis,
                   mrs.timeAxis,
                   mrs.basis, base_poly, res.metab_groups, res.g)
    return SpecToFID(spec)  # predict FID not Spec

867

868
869
870
def plot_indiv_stacked(mrs, res, ppmlim=(.2, 4.2)):
    """
       Overlay every metabolite's individual fit on the data in one plot

       Parameters
       ----------
       mrs    : MRS Object
       res    : fit results object, forwarded to pred() for each metabolite
       ppmlim : tuple
              (MIN,MAX) displayed chemical shift range

       Returns
       -------
       plotly Figure with one 'data' trace followed by one trace per
       metabolite (named after the metabolite)
    """
    data_line = dict(color='rgb(67,67,67)', width=.5)
    metab_line = dict(color='rgb(253,59,59)', width=2)

    fig = go.Figure()
    ppm = mrs.getAxes()

    data_spec = np.real(FID2Spec(mrs.FID))
    fig.add_trace(go.Scatter(x=ppm, y=data_spec,
                             mode='lines',
                             name='data',
                             line=data_line))

    # One trace per metabolite, each predicted on its own via pred().
    for metab_name in mrs.names:
        metab_spec = np.real(FID2Spec(pred(mrs, res, metab_name)))
        fig.add_trace(go.Scatter(x=ppm, y=metab_spec,
                                 mode='lines',
                                 name=metab_name,
                                 line=metab_line))

    # Reversed range gives the conventional high-to-low ppm axis.
    fig.update_layout(
        xaxis=dict(title_text='Chemical shift (ppm)',
                   tick0=2, dtick=.5,
                   range=[ppmlim[1], ppmlim[0]]),
        yaxis=dict(zeroline=True,
                   zerolinewidth=1,
                   zerolinecolor='Gray',
                   showgrid=False, showticklabels=False))
    fig.update_layout(template='plotly_white')

    return fig


908
909
910
def plot_indiv(mrs, res, ppmlim=(.2, 4.2), ncols=3):
    """
       Plot each metabolite's individual fit in its own subplot

       Every panel shows the real part of the data spectrum overlaid with
       the prediction for one metabolite, computed without baseline.

       Parameters
       ----------
       mrs    : MRS Object
       res    : fit results object; res.predictedFID(mrs, mode=metab,
              noBaseline=True) is called for each metabolite
       ppmlim : tuple
              (MIN,MAX) displayed chemical shift range
       ncols  : int
              number of subplot columns (default 3)

       Returns
       -------
       plotly Figure
    """
    colors = dict(data='rgb(67,67,67)',
                  pred='rgb(253,59,59)')
    line_size = dict(data=.5,
                     pred=2)

    nrows = int(np.ceil(mrs.numBasis / ncols))

    fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=mrs.names)
    axis = mrs.getAxes()

    # The data spectrum is the same in every panel: compute it once.
    y_data = np.real(FID2Spec(mrs.FID))

    for i, metab in enumerate(mrs.names):
        # Panels are filled row-major; plotly rows/cols are 1-based.
        row, col = divmod(i, ncols)
        row, col = row + 1, col + 1

        y_fit = np.real(FID2Spec(res.predictedFID(mrs, mode=metab, noBaseline=True)))

        trace1 = go.Scatter(x=axis, y=y_data,
                            mode='lines',
                            line=dict(color=colors['data'], width=line_size['data']))
        trace2 = go.Scatter(x=axis, y=y_fit,
                            mode='lines',
                            line=dict(color=colors['pred'], width=line_size['pred']))
        fig.add_trace(trace1, row=row, col=col)
        fig.add_trace(trace2, row=row, col=col)

        # Target this panel's axes by grid position instead of eval-ing
        # "fig.layout.xaxisN"; make_subplots numbers axes row-major, so this
        # addresses the same axis objects.
        fig.update_xaxes(tick0=2, dtick=.5, range=[ppmlim[1], ppmlim[0]],
                         showticklabels=False, row=row, col=col)
        fig.update_yaxes(zeroline=True, zerolinewidth=1, zerolinecolor='Gray',
                         showgrid=False, showticklabels=False, row=row, col=col)

    # Figure-wide settings only need applying once, not once per panel.
    fig.update_layout(template='plotly_white',
                      showlegend=False,
                      height=1000,
                      font=dict(size=10),
                      title='Individual fits')
    # Subplot titles are stored as layout annotations; colour them red.
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(size=10, color='#ff0000')

    return fig

Saad Jbabdi's avatar
Saad Jbabdi committed
955