Commit 78712745 authored by William Clarke's avatar William Clarke
Browse files

Using autopep8 to get all formatting sorted.

parent 539a4f50
......@@ -3,4 +3,4 @@ __version__ = get_versions()['version']
del get_versions
# from fsl_mrs.core import MRS
# from fsl_mrs.core import MRSI
\ No newline at end of file
# from fsl_mrs.core import MRSI
from fsl_mrs.utils.dynamic.variable_mapping import VariableMapping
from fsl_mrs.utils.dynamic.dynmrs import dynMRS
\ No newline at end of file
from fsl_mrs.utils.dynamic.dynmrs import dynMRS
......@@ -94,7 +94,6 @@ class dynMRS(object):
raise (Exception(f'Unrecognised method {method}'))
res_list = self.collect_results(x, model, method, ppmlim, baseline_order)
if verbose:
print(f"Fitting completed in {time.time()-start_time} seconds.")
return {'x': x, 'resList': res_list}
......@@ -190,7 +189,6 @@ class dynMRS(object):
def dyn_loss_grad(self, x):
"""Add gradients across data list"""
g = []
mapped = self.vm.free_to_mapped(x)
LUT = self.vm.free_to_mapped(np.arange(self.vm.nfree), copy_only=True)
dfdx = 0
......@@ -247,4 +245,3 @@ class dynMRS(object):
results.loadResults(mrs, np.hstack(mapped[t]))
dynresList.append(results)
return dynresList
......@@ -79,8 +79,6 @@ class VariableMapping(object):
def __repr__(self) -> str:
return str(self)
def calc_nfree(self):
"""
Calculate number of free parameters based on mapped behaviour
......@@ -230,7 +228,7 @@ class VariableMapping(object):
elif (self.Parameters[name] == 'variable'): # copy one param for each time point
for t in range(self.ntimes):
mapped_params[t, index] = p[counter :counter + nmapped]
mapped_params[t, index] = p[counter:counter + nmapped]
counter += nmapped
else:
......@@ -248,7 +246,7 @@ class VariableMapping(object):
for t in range(self.ntimes):
mapped_params[t, index] = mapped[t, :]
else:
mapped = np.empty((self.ntimes, nmapped),dtype=object)
mapped = np.empty((self.ntimes, nmapped), dtype=object)
for i in range(nmapped):
params = p[counter:counter + nfree]
for t in range(self.ntimes):
......@@ -340,7 +338,7 @@ class VariableMapping(object):
return free_params
def get_gradient_fcn(self,param_name):
def get_gradient_fcn(self, param_name):
"""
Get the gradient function for a given parameter
Returns:
......
This diff is collapsed.
......@@ -144,8 +144,8 @@ def filter(mrs, FID, ppmlim, filter_type='bandpass'):
"""
# Sampling frequency (Hz)
fs = 1 / mrs.dwellTime
nyq = 0.5 * fs
fs = 1 / mrs.dwellTime
nyq = 0.5 * fs
f1 = np.abs(ppm2hz(mrs.centralFrequency, ppmlim[0]) / nyq)
f2 = np.abs(ppm2hz(mrs.centralFrequency, ppmlim[1]) / nyq)
......@@ -170,11 +170,11 @@ def ts_to_ts(old_ts, old_dt, new_dt, new_n):
new_dt: Output dwelltime
new_n: Output number of points
"""
old_n = old_ts.shape[0]
old_t = np.linspace(old_dt, old_dt * old_n, old_n) - old_dt
new_t = np.linspace(new_dt, new_dt * new_n, new_n) - new_dt
old_n = old_ts.shape[0]
old_t = np.linspace(old_dt, old_dt * old_n, old_n) - old_dt
new_t = np.linspace(new_dt, new_dt * new_n, new_n) - new_dt
f = interp1d(old_t, old_ts, axis=0)
f = interp1d(old_t, old_ts, axis=0)
new_ts = f(new_t)
return new_ts
......@@ -355,17 +355,18 @@ def calculate_lap_cov(x, f, data, sig2=None):
# Various utilities
def multiply(x, y):
    """
    Elementwise multiply numpy arrays x and y.

    Parameters:
        x : array-like
        y : array-like (must contain the same number of elements as x)

    Returns:
        array with the same shape as x
    """
    shape = x.shape
    # Flatten both operands so inputs of differing (but compatible) shapes
    # can be multiplied elementwise, then restore x's shape.
    r = x.flatten() * y.flatten()
    return np.reshape(r, shape)
def shift_FID(mrs, FID, eps):
    """
    Shift FID in spectral domain.

    Parameters:
        mrs : MRS object (provides timeAxis)
        FID : array-like
        eps : frequency shift, applied as exp(-1j*2*pi*t*eps)
              (presumably in Hz — confirm against callers)

    Returns:
        array-like shifted FID
    """
    t = mrs.timeAxis
    # Multiplication by a complex exponential in time <=> shift in frequency.
    FID_shifted = multiply(FID, np.exp(-1j * 2 * np.pi * t * eps))
    return FID_shifted
def blur_FID(mrs, FID, gamma):
    """
    Blur FID in spectral domain (Lorentzian line broadening).

    Parameters:
        mrs   : MRS object (provides timeAxis)
        FID   : array-like
        gamma : Lorentzian line broadening factor

    Returns:
        array-like blurred FID
    """
    t = mrs.timeAxis
    # Exponential decay in time <=> Lorentzian broadening in frequency.
    FID_blurred = multiply(FID, np.exp(-t * gamma))
    return FID_blurred
def blur_FID_Voigt(mrs, FID, gamma, sigma):
    """
    Blur FID in spectral domain (Voigt lineshape).

    Parameters:
        mrs   : MRS object (provides timeAxis)
        FID   : array-like
        gamma : Lorentzian line broadening
        sigma : Gaussian line broadening

    Returns:
        array-like blurred FID
    """
    t = mrs.timeAxis
    # Combined Lorentzian (linear in t) and Gaussian (quadratic in t) decay.
    FID_blurred = multiply(FID, np.exp(-t * (gamma + t * sigma**2 / 2)))
    return FID_blurred
def rescale_FID(x, scale=100):
    """
    Useful for ensuring values are within nice range.

    Forces norm of 1D arrays to be = scale.
    Forces norm of column-mean of 2D arrays to be = scale
    (i.e. preserves relative norms of the columns).

    Parameters
    ----------
    x : 1D or 2D array, or list of arrays
    scale : float

    Returns
    -------
    rescaled data (same container type as input) and the
    applied multiplicative factor (scale / norm)
    """
    y = x.copy()
    if isinstance(y, list):
        # For a list, normalise by the norm of the mean element so relative
        # norms between list entries are preserved.
        factor = np.linalg.norm(sum(y) / len(y))
        return [yy / factor * scale for yy in y], 1 / factor * scale

    if y.ndim == 1:
        factor = np.linalg.norm(y)
    else:
        factor = np.linalg.norm(np.mean(y, axis=1), axis=0)
    y = y / factor * scale
    return y, 1 / factor * scale
def create_peak(mrs, ppm, amp, gamma=0, sigma=0):
    """
    Creates FID for peak(s) at specific ppm position(s).

    Parameters
    ----------
    mrs : MRS object (contains time information)
    ppm : float or list of floats
    amp : float or list of floats (one amplitude per ppm)
    gamma : float
        Peak Lorentzian dispersion
    sigma : float
        Peak Gaussian dispersion

    Returns
    -------
    array-like FID
    """
    # Accept scalars by promoting them to single-element lists.
    if isinstance(ppm, (float, int)):
        ppm = [float(ppm), ]
    if isinstance(amp, (float, int)):
        amp = [float(amp), ]

    t = mrs.timeAxis
    out = np.zeros(t.shape[0], dtype=np.complex128)

    for p, a in zip(ppm, amp):
        freq = ppm2hz(mrs.centralFrequency, p)
        x = a * np.exp(1j * 2 * np.pi * freq * t).flatten()

        if gamma > 0 or sigma > 0:
            x = blur_FID_Voigt(mrs, x, gamma, sigma)

        # dephase so each peak starts with zero phase
        x = x * np.exp(-1j * np.angle(x[0]))
        out += x
    return out
def extract_spectrum(mrs,FID,ppmlim=(0.2,4.2),shift=True):
def extract_spectrum(mrs, FID, ppmlim=(0.2, 4.2), shift=True):
"""
Extracts spectral interval
Parameters:
mrs : MRS object
FID : array-like
ppmlim : tuple
Returns:
array-like
"""
spec = FIDToSpec(FID)
first, last = mrs.ppmlim_to_range(ppmlim=ppmlim,shift=shift)
spec = spec[first:last]
spec = FIDToSpec(FID)
first, last = mrs.ppmlim_to_range(ppmlim=ppmlim, shift=shift)
spec = spec[first:last]
return spec
def normalise(x, axis=0):
    """
    Divides x by norm of x (along the given axis).
    """
    return x / np.linalg.norm(x, axis=axis)
def ztransform(x, axis=0):
    """
    Demeans x and makes norm(x) = 1.
    """
    return (x - np.mean(x, axis=axis)) / np.std(x, axis) / np.sqrt(x.size)


def correlate(x, y):
    """
    Computes correlation between complex signals x and y.
    Uses formula : sum( conj(z(x))*z(y) ) where z() is the ztransform.
    """
    return np.real(np.sum(np.conjugate(ztransform(x)) * ztransform(y)))
def phase_correct(mrs, FID, ppmlim=(1, 3)):
    """
    Apply zero-order phase correction to FID.

    Grid-searches 1000 phases in [0, 2*pi] and keeps the one that minimises
    the number of negative real spectral points inside ppmlim.

    Parameters:
        mrs : MRS object (provides ppmlim_to_range)
        FID : array-like
        ppmlim : tuple of ppm bounds used for the search

    Returns:
        phase-corrected FID (array-like)
    """
    first, last = mrs.ppmlim_to_range(ppmlim)
    phases = np.linspace(0, 2 * np.pi, 1000)
    x = []
    for phase in phases:
        f = np.real(np.fft.fft(FID * np.exp(1j * phase), axis=0))
        x.append(np.sum(f[first:last] < 0))
    phase = phases[np.argmin(x)]
    return FID * np.exp(1j * phase)
def detrend(data, deg=1, keep_mean=True):
    """
    Remove polynomial trend from data.
    Works along first dimension.

    Parameters:
        data : array-like (trend removed along axis 0)
        deg : int, polynomial degree of the trend model
        keep_mean : bool, add the mean back after detrending

    Returns:
        detrended array, same shape as data
    """
    n = data.shape[0]
    x = np.arange(n)
    # Vandermonde-style design matrix [1, x, x**2, ...].
    M = np.zeros((n, deg + 1))
    for i in range(deg + 1):
        M[:, i] = x**i

    # Least-squares polynomial fit via pseudo-inverse.
    beta = np.linalg.pinv(M) @ data
    pred = M @ beta

    m = 0
    if keep_mean:
        m = np.mean(data, axis=0)
    return data - pred + m
def regress_out(x, conf, keep_mean=True):
    """
    Linear deconfounding: remove the component of x explained by conf.

    Parameters:
        x : array-like data
        conf : confound array, or list of confounds (stacked column-wise)
        keep_mean : bool, add the mean of x back after deconfounding

    Returns:
        deconfounded array, same shape as x
    """
    if isinstance(conf, list):
        confa = np.squeeze(np.asarray(conf)).T
    else:
        confa = conf
    if keep_mean:
        m = np.mean(x, axis=0)
    else:
        m = 0
    # Subtract the projection of x onto the confound subspace.
    return x - confa @ (np.linalg.pinv(confa) @ x) + m
def parse_metab_groups(mrs, metab_groups):
    """
    Creates list of indices per metabolite group.

    Parameters
    ----------
    metab_groups :
        - A single index  : output is a list of 0's
        - A single string : corresponding metab in own group
        - The strings 'separate_all' or 'combine_all'
        - A list:
            - list of integers : output same as input
            - list of strings  : each string is interpreted as metab name and has own group
        Entries in the lists above can also be lists, in which case the
        corresponding metabs are grouped.
    mrs : MRS Object (provides numBasis and names)

    Returns
    -------
    list of integers (one group index per basis metabolite)
    """
    if isinstance(metab_groups, list) and len(metab_groups) == 1:
        metab_groups = metab_groups[0]

    out = [0] * mrs.numBasis

    if isinstance(metab_groups, int):
        return out

    if isinstance(metab_groups, str):
        if metab_groups.lower() == 'separate_all':
            return list(range(mrs.numBasis))

        if metab_groups.lower() == 'combine_all':
            return [0] * mrs.numBasis

        entry = metab_groups.split('+')
        # NOTE(review): str.split always returns a list, so the str branch
        # below is defensive/dead code — kept for behavioural parity.
        if isinstance(entry, str):
            out[mrs.names.index(entry)] = 1
        elif isinstance(entry, list):
            for n in entry:
                assert(isinstance(n, str))
                out[mrs.names.index(n)] = 1
        return out

    if isinstance(metab_groups, list):
        if isinstance(metab_groups[0], int):
            assert(len(metab_groups) == mrs.numBasis)
            return metab_groups

        grpcounter = 0
        for entry in metab_groups:
            if isinstance(entry, str):
                entry = entry.split('+')
            grpcounter += 1
            if isinstance(entry, str):
                out[mrs.names.index(entry)] = grpcounter
            elif isinstance(entry, list):
                for n in entry:
                    assert(isinstance(n, str))
                    out[mrs.names.index(n)] = grpcounter
            else:
                raise(Exception('entry must be string or list of strings'))

    # Re-base group indices so the smallest group index is 0.
    m = min(out)
    if m > 0:
        out = [x - m for x in out]
    return out
# ----- MRSI stuff ---- #
def volume_to_list(data,mask):
def volume_to_list(data, mask):
"""
Turn voxels within mask into list of data
Parameters
----------
data : 4D array
mask : 3D array
......@@ -672,11 +673,12 @@ def volume_to_list(data,mask):
voxels = []
for x, y, z in it.product(range(nx), range(ny), range(nz)):
if mask[x, y, z]:
voxels.append((x, y, z))
voxels.append((x, y, z))
voxdata = [data[x, y, z, :] for (x, y, z) in voxels]
return voxdata
def list_to_volume(data_list, mask, dtype=float):
    """
    Turn list of voxelwise data into 4D volume.

    Parameters
    ----------
    data_list : list of array-like, one entry per True voxel in mask
                (iterated in C order, matching volume_to_list)
    mask : 3D array
    dtype : output datatype

    Returns
    -------
    4D volume (or 3D when each list entry holds a single value)
    """
    nx, ny, nz = mask.shape
    nt = data_list[0].size
    if nt > 1:
        data = np.zeros((nx, ny, nz, nt), dtype=dtype)
    else:
        data = np.zeros((nx, ny, nz,), dtype=dtype)
    i = 0
    for x, y, z in it.product(range(nx), range(ny), range(nz)):
        if mask[x, y, z]:
            if nt > 1:
                data[x, y, z, :] = data_list[i]
            else:
                data[x, y, z] = data_list[i]
            i += 1

    return data
def unravel(idx, mask):
    """
    Return the (x, y, z) coordinates of the idx-th True voxel in mask.

    Voxels are counted in C order (x outermost), matching volume_to_list.
    Returns None implicitly if idx exceeds the number of masked voxels.
    """
    nx, ny, nz = mask.shape
    counter = 0
    for x, y, z in it.product(range(nx), range(ny), range(nz)):
        if mask[x, y, z]:
            if counter == idx:
                return np.array([x, y, z])
            counter += 1
def ravel(arr, mask):
    """
    Return the flat within-mask index of coordinate arr (inverse of unravel).

    arr is compared with == against [x, y, z]; pass a list for exact match.
    Returns None implicitly if arr is not a masked voxel.
    """
    nx, ny, nz = mask.shape
    counter = 0
    for x, y, z in it.product(range(nx), range(ny), range(nz)):
        if mask[x, y, z]:
            if arr == [x, y, z]:
                return counter
            counter += 1
# FMRS Stuff
def smooth_FIDs(FIDlist,window):
def smooth_FIDs(FIDlist, window):
"""
Smooth a list of FIDs (makes sense if acquired one after the other as the smoothing is done along the "time" dimension
Smooth a list of FIDs
(makes sense if acquired one after the other as the smoothing is done along the "time" dimension)
Note: at the edge of the list of FIDs the smoothing wraps around the list so make sure that the beginning and the end are 'compatible'
Note: at the edge of the list of FIDs the smoothing wraps around the
list so make sure that the beginning and the end are 'compatible'.
Parameters:
-----------
......@@ -747,12 +751,12 @@ def smooth_FIDs(FIDlist,window):
list of FIDs
"""
sFIDlist = []