import numpy as np
import matplotlib.pyplot as plt
import torch
[docs]
def draw(I, xI=None, fig=None, function=np.sum, **kwargs):
    """
    Generate three orthogonal projections of the volume 'I', one per spatial axis.

    The views were originally labelled coronal, transverse and sagittal, but which
    anatomical plane each one shows depends on how 'I' is oriented.

    Parameters:
    -----------
    I : array
        A 3D image volume (anything np.asarray accepts). Only the last 3 axes are
        treated as spatial; each projected 2D result is squeezed before display.
    xI : list of array
        A list of the coordinates along each of the last 3 dimensions of I.
        Default - voxel indices centered on 0.
    fig : matplotlib.figure
        A figure on which new plots will be generated; it is cleared first.
        Default - a new figure from plt.figure().
    function : Python function
        Called as function(I, axis) to collapse one axis. Default - np.sum, which
        produces a sum projection; pass np.max for a Maximum Intensity Projection.
    **kwargs
        Forwarded unchanged to every ax.imshow call.

    Returns:
    --------
    fig : matplotlib.figure
        The figure that was drawn on (returned so callers can e.g. save it as a png).
    axs : list of matplotlib.axes
        The 3 axes, top to bottom: projections along axis -3, -2 and -1 of I.
    """
    # Convert first so .shape below also works for lists and other array-likes.
    I = np.asarray(I)
    if xI is None:
        xI = [np.arange(n) - (n - 1) / 2 for n in I.shape[-3:]]
    if fig is None:
        fig = plt.figure()
    # Voxel spacing per axis; an axis with a single sample gets unit spacing so
    # its extent is one voxel wide instead of raising IndexError.
    dI = [x[1] - x[0] if len(x) > 1 else 1.0 for x in xI]

    def _extent(col, row):
        # imshow extent (left, right, bottom, top) covering whole voxels; the row
        # axis runs from its first coordinate at the top to its last at the bottom.
        return (
            xI[col][0] - dI[col] / 2,
            xI[col][-1] + dI[col] / 2,
            xI[row][-1] + dI[row] / 2,
            xI[row][0] - dI[row] / 2,
        )

    fig.clf()
    axs = []
    # (axis collapsed by `function`, axis shown horizontally, axis shown vertically)
    views = ((-3, -1, -2), (-2, -1, -3), (-1, -2, -3))
    for position, (collapsed, col, row) in enumerate(views, start=1):
        ax = fig.add_subplot(3, 1, position)
        ax.imshow(function(I, collapsed).squeeze(), extent=_extent(col, row), **kwargs)
        axs.append(ax)
    # 12/03/25: returns (fig, axs) rather than np.array(axs) so a caller can save the figure.
    return fig, axs
[docs]
def getslice(I, ax):
    """
    Return the middle 2D slice of the 3D image volume 'I' along one axis.

    Parameters:
    -----------
    I : array
        A 3D image volume (numpy array or torch tensor); leading axes are kept.
    ax : int
        Options : -1, -2, -3; the axis of I from which the middle slice is taken.

    Returns:
    --------
    A view of I at index I.shape[ax] // 2 along axis 'ax'.

    Raises:
    -------
    ValueError
        If 'ax' is not one of -1, -2, -3 (previously this silently returned None).
    """
    if ax not in (-1, -2, -3):
        raise ValueError(f"ax must be -1, -2 or -3, got {ax!r}")
    # Full slices for every axis after `ax`, so the integer index lands on axis `ax`.
    trailing = (slice(None),) * (-1 - ax)
    return I[(Ellipsis, I.shape[ax] // 2) + trailing]
[docs]
def interp(xI, I, Xs, **kwargs):
    """
    Sample the channel-first image 'I' at physical coordinates 'Xs' via grid_sample.

    Parameters:
    -----------
    xI : list of torch.Tensor
        A list of the coordinates along each spatial dimension of I. Only the first
        and last entry of each are used, so the grid is assumed uniformly spaced.
    I : torch.Tensor
        Image with a leading channel axis followed by len(xI) spatial axes.
    Xs : torch.Tensor
        Sample points with shape (..., len(xI)); the last axis lists coordinates in
        the same order as xI. grid_sample requires as many leading axes as I has
        spatial axes.
    **kwargs
        Forwarded to torch.nn.functional.grid_sample (e.g. mode, padding_mode).

    Returns:
    --------
    output : torch.Tensor
        Shape (C,) + Xs.shape[:-1], where C = I.shape[0].
    """
    origin = torch.stack([x[0] for x in xI])
    span = torch.stack([x[-1] - x[0] for x in xI])
    # Map physical coordinates onto grid_sample's [-1, 1] range (align_corners=True
    # puts -1 and 1 on the centres of the first and last voxels).
    normalized = (Xs - origin) / span * 2 - 1
    # grid_sample orders the last axis as (x, y, ...), i.e. reversed relative to xI.
    grid = normalized[None].flip(-1)
    sampled = torch.nn.functional.grid_sample(I[None], grid, align_corners=True, **kwargs)
    return sampled[0]
[docs]
def interp1d(xI, squish, Xs, dd=None, **kwargs):
    """
    Hack for 1D interpolation of several profiles, built on 2D grid_sample.

    Parameters:
    -----------
    xI : list of torch.Tensor
        Only xI[0] is used: the coordinates (assumed uniformly spaced) along axis 1
        of 'squish'.
    squish : torch.Tensor
        Shape (C, N): C profiles, each sampled at the N coordinates xI[0].
    Xs : torch.Tensor
        Sample locations; only Xs[..., 0] is used. It is squeezed, so size-1 axes
        do not appear in the output shape.
    dd : dict, optional
        A Python dictionary with keys 'device' and/or 'dtype', used to build the
        helper coordinate tensor. Default - the device and dtype of xI[0].
    **kwargs
        Forwarded to grid_sample through interp (e.g. mode, padding_mode).

    Returns:
    --------
    output : torch.Tensor
        Shape (C,) + Xs[..., 0].squeeze().shape.
    """
    if dd is None:
        dd = {'device': xI[0].device, 'dtype': xI[0].dtype}
    # grid_sample has no 1D mode, so make the problem 2D: the real coordinate goes
    # first and a constant zero acts as the second, fake coordinate.
    coord = Xs[..., 0].squeeze()
    samples = torch.stack([coord, torch.zeros_like(coord)], -1)
    # Keep the channel and 1D axes of the input and add a fake trailing axis of size 1.
    squishin = squish[:, :, None]
    # Coordinates of the fake axis; the zero second coordinate lies midway between them.
    # torch.tensor (not the legacy torch.Tensor constructor) accepts a dtype keyword.
    fake_axis = torch.tensor([-0.5, 0.5], **dd)
    out = interp([xI[0], fake_axis], squishin, samples.reshape(-1, 1, 2), **kwargs)
    return out.reshape((squish.shape[0],) + samples.shape[:-1])
# Integration of the velocity field v follows; note that there is NO batch dimension.
[docs]
def phii_from_v(xv, v):
    """
    Integrate the time-varying field 'v' into a sampled transform 'phii'.

    The total time 1 is split into v.shape[0] steps of size dt = 1 / v.shape[0].
    At each step the current displacement (phii - grid) is resampled at
    grid - v[t] * dt, using border padding, and those sample points are added back.
    The name suggests phii is the inverse transform; confirm against the caller.

    Parameters:
    -----------
    xv : list of torch.Tensor
        Coordinates along each of the 3 spatial axes (combined with ij meshgrid
        indexing).
    v : torch.Tensor
        Velocity samples with time on axis 0. Each v[t] must broadcast against the
        (n0, n1, n2, 3) coordinate grid; presumably v is (nt, n0, n1, n2, 3).
        There is no batch dimension.

    Returns:
    --------
    phii : torch.Tensor
        Shape (n0, n1, n2, 3): the transformed position of every grid point.
    """
    dt = 1.0 / v.shape[0]
    grid = torch.stack(torch.meshgrid(xv, indexing='ij'), -1)
    phii = grid.clone()
    for v_t in v:
        # Points reached by stepping backwards along the velocity; no batch axis here.
        points = grid - v_t * dt
        displacement = (phii - grid).permute(3, 0, 1, 2)
        resampled = interp(xv, displacement, points, padding_mode='border')
        phii = resampled.permute(1, 2, 3, 0) + points
    return phii
[docs]
def toblocks(Jd, blocksize):
    """
    Split a channel-first 3D volume into a grid of cubic blocks.

    Each spatial axis is reflect-padded at its start (the "left" side) up to a
    multiple of 'blocksize'. torch's reflect padding requires every pad amount to be
    smaller than the corresponding axis length.

    Parameters:
    -----------
    Jd : torch.Tensor
        A 3D image volume with 1 additional leading dimension: shape (C, n0, n1, n2).
    blocksize : int
        Edge length of each cubic block, in voxels.

    Returns:
    --------
    Jdpp : torch.Tensor
        Shape (nb0, nb1, nb2, C, blocksize, blocksize, blocksize), where
        nbk = ceil(nk / blocksize). Jdpp[i, j, k] is block (i, j, k) of the padded volume.
    """
    blocksize = int(blocksize)
    # Plain Python ints: no dependence on global device/dtype settings, and
    # F.pad / reshape receive integers rather than float tensors.
    spatial = [int(n) for n in Jd.shape[1:]]
    nblocks = [-(-n // blocksize) for n in spatial]  # ceiling division
    topad = [nb * blocksize - n for nb, n in zip(nblocks, spatial)]
    # F.pad takes (before, after) pairs starting from the last axis; this pads the left.
    topadlist = [topad[-1], 0, topad[-2], 0, topad[-3], 0]
    Jdp = torch.nn.functional.pad(Jd, topadlist, mode='reflect')
    Jdpv = Jdp.reshape(Jdp.shape[0], nblocks[0], blocksize, nblocks[1], blocksize, nblocks[2], blocksize)
    # Move the block indices to the front: (nb0, nb1, nb2, C, bs, bs, bs).
    Jdpp = Jdpv.permute(1, 3, 5, 0, 2, 4, 6)
    return Jdpp
[docs]
def fromblocks(fphiIpp, Jdsize):
    """
    Reassemble blocks in the layout produced by toblocks and crop off the padding.

    Parameters:
    -----------
    fphiIpp : torch.Tensor
        Blocks with shape (nb0, nb1, nb2, C, bs, bs, bs); bs is read from the last axis.
    Jdsize : sequence of int or torch.Tensor
        The shape of the downsampled target data; only its last 3 entries
        (n0, n1, n2) are used.

    Returns:
    --------
    output : torch.Tensor
        Shape (C, n0, n1, n2). The padding that was added at the start of each
        spatial axis is removed.
    """
    blocksize = int(fphiIpp.shape[-1])
    # Previously hard-coded to 1; taking it from the block layout lets C > 1 work too.
    nchannels = int(fphiIpp.shape[3])
    target = [int(n) for n in Jdsize[-3:]]
    nblocks = [-(-n // blocksize) for n in target]  # ceiling division
    # Integer pad amounts, so they can be used directly as slice starts.
    topad = [nb * blocksize - n for nb, n in zip(nblocks, target)]
    # Undo the permutation: back to (C, nb0, bs, nb1, bs, nb2, bs).
    fphiIpv = fphiIpp.permute(3, 0, 4, 1, 5, 2, 6)
    fphiIp = fphiIpv.reshape(nchannels, nblocks[0] * blocksize, nblocks[1] * blocksize, nblocks[2] * blocksize)
    # The padding sits at the start of each axis, so keep the trailing voxels.
    return fphiIp[:, topad[0]:, topad[1]:, topad[2]:]
[docs]
def measure_matching_dot(qIU, wIU, phiiQJU, wphiiQJU, SigmaQIU):
    """
    Gaussian-kernel inner product between two weighted point sets.

    Computes sum_ij wIU[i] * wphiiQJU[j] * exp(-sum_d (qIU[i,d] - phiiQJU[j,d])**2 / 2 / SigmaQIU[d]).

    Parameters:
    -----------
    qIU : torch.Tensor
        First point set, shape (N, D).
    wIU : torch.Tensor
        Weights of the first point set, shape (N,).
    phiiQJU : torch.Tensor
        Second point set, shape (M, D); presumably target points mapped through phii.
    wphiiQJU : torch.Tensor
        Weights of the second point set, shape (M,).
    SigmaQIU : torch.Tensor or float
        Kernel variance; must broadcast against (N, M, D), e.g. a scalar or a
        per-dimension (D,) tensor.

    Returns:
    --------
    output : torch.Tensor
        0-dim tensor holding the weighted kernel sum.
    """
    # Pairwise differences, shape (N, M, D).
    pairwise = qIU[:, None] - phiiQJU[None, :]
    # Per-pair exponent: squared differences scaled by 1 / (2 * Sigma), summed over D.
    exponent = torch.sum(pairwise**2 / 2 / SigmaQIU, -1)
    kernel = torch.exp(-exponent) * wIU[:, None] * wphiiQJU[None, :]
    return kernel.sum()