Source code for slice_alignment

import argparse
import torch
# torch.concatenate=torch.cat # compatibility
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
# from os.path import join,basename,splitext
from scipy.interpolate import interpn
from pathlib import Path
import os
from slice_alignment_help import *
from IPython.display import display

def _resolve_dtype(dtype):
    """Convert a ``-dtype`` command-line value into something torch accepts as a dtype.

    Parameters
    ----------
    dtype : str or object
        A dtype name such as ``'float32'`` or ``'torch.float64'``.  Anything that
        is not a string is returned unchanged.

    Returns
    -------
    torch.dtype or type
        The matching ``torch.dtype``.  The name ``'float'`` maps to the Python
        ``float`` type so that ``-dtype float`` is identical to the parser default.

    Raises
    ------
    argparse.ArgumentTypeError
        If the name does not correspond to a ``torch.dtype``.
    """
    if not isinstance(dtype, str):
        return dtype
    name = dtype.strip()
    if name.startswith('torch.'):
        name = name[len('torch.'):]
    if name == 'float':
        # Keep "-dtype float" equal to the default (Python float, which torch
        # interprets as float64) instead of torch.float, which is float32.
        return float
    resolved = getattr(torch, name, None)
    if not isinstance(resolved, torch.dtype):
        raise argparse.ArgumentTypeError(f'unknown torch dtype: {dtype!r}')
    return resolved


def _to_numpy(x):
    """Return ``x`` as a NumPy array, detaching torch tensors and moving them to the CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def main():
    """Perform atlas-free slice alignment of a set of images.

    All inputs are read from the command line.

    Parameters
    ----------
    outdir : path
        The directory where all outputs are saved (created if missing).
    -fnames : list of str
        Required. The image files to be registered, in alignment order.
    -npad : int
        Default - 0; number of slices reflect-padded at each end of the stack.
    -down : int
        Default - 1; factor multiplying the in-plane pixel coordinates. The
        velocity-field sample spacing is ``2*down``.
    -a : float
        Default - 10.0; highpass operator parameter for 2D registration (used
        for the operator ``L`` and the preliminary registration).
    -p : float
        Default - 2.0; lowpass operator parameter for 2D registration.
    -niter_big_loop : int
        Default - 400; number of outer (register, then update atlas) iterations.
    -niter_reg : int
        Default - 5; registration iterations per outer iteration.
    -niter_atlas : int
        Default - 5; atlas-estimation iterations per outer iteration.
    -asquare : float
        Default - 0.25**2; parsed, but the main loop overwrites it with
        ``asquare0`` before use.
    -asquare0 : float
        Default - 3.5**2; the ``asquare`` value passed to the atlas update.
    -anisotropy_factor : float
        Default - 0.01; passed to the atlas update in the main loop.
    -epT, -epL, -epv : float
        Defaults - 0, 0, 500; passed to ``weighted_see_registration`` in the
        main loop.
    -c : float
        Default - 2; passed to ``robust_loss``.
    -sigmaM, -sigmaR : float
        Defaults - 1, 500; passed to ``weighted_see_registration`` in the main loop.
    -a_reg : float
        Default - 6; the ``a`` value used by the registration in the main loop.
    -device : str
        Default - cpu; the device where PyTorch computations are performed.
    -dtype : str
        Default - float (torch reads Python ``float`` as float64); a torch
        dtype name such as ``float32`` or ``float64``.
    --enable_deformation : bool
        TODO: NOT CURRENTLY SUPPORTED (accepted but not used).
    --remove_artifacts : bool
        If present, give zero weight to rows or columns whose pixels are all
        >= 0.95 in every channel.
    --saveAllFigs : bool
        If present, save all figures into 'outdir'.

    Output files
    ------------
    Each file stores a single array under the key ``data``.

    A.npz : final per-slice affine matrices ``A``.
    v.npz : final velocity field ``v``.
    Esave.npz : per-iteration [total, robust matching, registration reg, atlas reg].
    RphiI.npz : output of ``transform_image`` with the final parameters.
    phiiRiJ.npz : output of ``inverse_transform_image`` applied to the data.
    Wshow.npz : ``phiiRiW * Wdetjacs`` from the last iteration.
    W_robust_loss.npz : weights returned by ``robust_loss``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('outdir', type=Path, help='The directory where all intermediate and final outputs should be stored')
    # Required: the loader iterates over this list, so a missing value used to crash with a TypeError.
    parser.add_argument('-fnames', nargs='+', type=str, required=True, help='List of file names in alignment order')
    parser.add_argument('-npad', default=0, type=int, help='Default - 0; The number of slices to be padded')
    parser.add_argument('-down', default=1, type=int, help='Default - 1; TODO - Add description')
    parser.add_argument('-a', default=10.0, type=float, help='Default - 10.0; Highpass operator for 2D registration')
    parser.add_argument('-p', default=2.0, type=float, help='Default - 2.0; Lowpass operator for 2D registration')
    parser.add_argument('-niter_big_loop', default=400, type=int, help='Default - 400; The number of iterations to perform registration of the whole dataset')
    parser.add_argument('-niter_reg', default=5, type=int, help='Default - 5; The number of iterations to perform registration on a subproblem')
    parser.add_argument('-niter_atlas', default=5, type=int, help='Default - 5; The number of iterations to perform atlas estimation on a subproblem')
    parser.add_argument('-asquare', default=0.25**2, type=float, help='Default - 0.25**2; A scalar used for registration')
    parser.add_argument('-asquare0', default=3.5**2, type=float, help='Default - 3.5**2; A scalar used for registration')
    parser.add_argument('-anisotropy_factor', default=0.01, type=float, help='Default - 0.01; A scalar used during registration')
    parser.add_argument('-epT', default=0.0, type=float, help='Default - 0; A scalar used during registration')
    parser.add_argument('-epL', default=0.0, type=float, help='Default - 0; A scalar used during registration')
    parser.add_argument('-epv', default=500.0, type=float, help='Default - 500; A scalar used during registration')
    parser.add_argument('-c', default=2.0, type=float, help='Default - 2; A scalar used during registration')
    parser.add_argument('-sigmaM', default=1.0, type=float, help='Default - 1; A scalar used during registration')
    parser.add_argument('-sigmaR', default=500.0, type=float, help='Default - 500; A scalar used during registration')
    parser.add_argument('-a_reg', default=6.0, type=float, help='Default - 6; A scalar used during registration')
    parser.add_argument('-device', default='cpu', help='Default - cpu; The device where PyTorch computations should occur during registration')
    # Without a converter any command-line value stayed a plain string, which torch rejects.
    parser.add_argument('-dtype', default=float, type=_resolve_dtype, help='Default - float (torch.float64); The dtype to be used during PyTorch computation, e.g. float32 or float64')
    parser.add_argument('--enable_deformation', action='store_true', help='TODO: NOT CURRENTLY SUPPORTED')
    parser.add_argument('--remove_artifacts', action='store_true', help='Remove rows or columns containing exclusively 1s. This type of artifact is common in certain use cases.')
    parser.add_argument('--saveAllFigs', action='store_true', help='If present, save all potential figures into \'outdir\'')

    args = parser.parse_args()
    outdir = args.outdir
    fnames = args.fnames
    npad = args.npad
    down = args.down
    a = args.a
    p = args.p
    niter_big_loop = args.niter_big_loop
    niter_reg = args.niter_reg
    niter_atlas = args.niter_atlas
    asquare = args.asquare
    asquare0 = args.asquare0
    anisotropy_factor = args.anisotropy_factor
    epT = args.epT
    epL = args.epL
    epv = args.epv
    c = args.c
    sigmaM = args.sigmaM
    sigmaR = args.sigmaR
    a_reg = args.a_reg
    device = args.device
    dtype = args.dtype
    remove_artifacts = args.remove_artifacts
    saveAllFigs = args.saveAllFigs

    def save_fig(name):
        # Save the current matplotlib figure into outdir when --saveAllFigs is set.
        if saveAllFigs:
            plt.savefig(os.path.join(outdir, name))

    # Create outdir if it doesn't already exist
    os.makedirs(outdir, exist_ok=True)

    # ===============================================
    # ===== (0) Load and clean the input images =====
    # ===============================================
    # Load the files in sequential order
    J_ = []
    W_ = []
    for fname in fnames:
        Ji = plt.imread(fname)
        if Ji.dtype == np.uint8:
            Ji = Ji / 255.0
        # Promote grayscale to 3 channels for every image (not only when removing
        # artifacts); the rest of the pipeline indexes a trailing channel axis.
        if Ji.ndim == 2:
            Ji = Ji[..., None].repeat(3, axis=-1)
        # Drop the alpha channel
        if Ji.shape[-1] == 4:
            Ji = Ji[..., :3]

        if remove_artifacts:
            # find any rows or columns that are all ones
            rowones = np.all(Ji >= 0.95, (0, -1))
            colones = np.all(Ji >= 0.95, (1, -1))
            Wi = (1 - rowones[None, :]) * (1 - colones[:, None])
        else:
            Wi = np.ones(Ji.shape[:2])
        J_.append(Ji)
        W_.append(Wi)

    # Interpolate every image onto a grid of the same size
    nJ = [Ji.shape for Ji in J_]
    nJ = np.max(nJ, 0)
    nJ = [len(J_), nJ[0], nJ[1]]
    x2d = [np.arange(n) * down - (n - 1) * down / 2 for n in nJ[1:]]
    X2d = np.stack(np.meshgrid(*x2d, indexing='ij'), -1)

    fig, ax = plt.subplots()
    J__ = []
    W__ = []
    for Ji, Wi in zip(J_, W_):
        x = [np.arange(n) * down - (n - 1) * down / 2 for n in Ji.shape[:2]]
        Ji_ = interpn(x, Ji, X2d, bounds_error=False, method='nearest')
        Wi_ = interpn(x, Wi, X2d, bounds_error=False, method='nearest')
        # Out-of-bounds samples come back as NaN: give them zero weight and zero value.
        Wi_ = (1.0 - np.isnan(Ji_[..., 0])) * Wi_
        Ji_[np.isnan(Ji_)] = 0
        Wi_[np.isnan(Wi_)] = 0
        J__.append(Ji_)
        W__.append(Wi_)
        ax.cla()
        ax.imshow(Ji_)

    # Note our convention is channel first: J is (channel, slice, row, col)
    J = np.stack(J__, 0).transpose(-1, 0, 1, 2)
    W = np.stack(W__)

    # Optionally, pad the first and last slices.
    # (The original test was "npad == 0", which only ever padded by zero.)
    if npad > 0:
        J = np.pad(J, ((0, 0), (npad, npad), (0, 0), (0, 0)), mode='reflect')
        W = np.pad(W, ((npad, npad), (0, 0), (0, 0)), mode='reflect')
    nJ = J.shape[1:]
    xJ = [np.arange(nJ[0]) - (nJ[0] - 1) / 2, x2d[0], x2d[1]]

    # Convert data to pytorch objects
    J = torch.tensor(J, dtype=dtype, device=device)
    W = torch.tensor(W, dtype=dtype, device=device)
    xJ = [torch.tensor(x, dtype=dtype, device=device) for x in xJ]

    # ===================================================================
    # ===== (1) Generate necessary data structures for registration =====
    # ===================================================================
    # In order to apply a linear transform, we need to generate a sequence of 2D affine matrices
    XJ = torch.stack(torch.meshgrid(xJ, indexing='ij'), -1)  # Pixel locations
    A = torch.eye(3)
    A = A[None].repeat(nJ[0], 1, 1)  # Repeat, so there is an affine matrix for each slice
    A = A2DtoA3D(A)
    Ai = torch.linalg.inv(A)
    Ai = Ai.to(dtype=dtype, device=device)
    XJ = XJ.to(dtype=dtype, device=device)
    Xs = AX(Ai, XJ)
    AJ = interp(xJ, J, XJ)

    # Apply the rigid transforms to our data
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(J[:, AJ.shape[1] // 2].permute(1, 2, 0)))
    save_fig('fig0.png')

    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(AJ[:, AJ.shape[1] // 2].permute(1, 2, 0)))
    save_fig('fig1.png')

    AJ = interp(xJ, J, Xs)
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(AJ[:, AJ.shape[1] // 2].permute(1, 2, 0)))
    save_fig('fig2.png')

    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(AJ[:, :, AJ.shape[2] // 2].permute(1, 2, 0)), aspect='auto')
    save_fig('fig3.png')

    # get a set of sample points for v
    extendv = 1.1  # i.e. make it 10% bigger than the domain of J, to avoid wraparound
    dv = down * 2
    vmin1 = torch.amin(xJ[1])
    vmin2 = torch.amin(xJ[2])
    vmax1 = torch.amax(xJ[1])
    vmax2 = torch.amax(xJ[2])
    vc1 = (vmin1 + vmax1) / 2
    vc2 = (vmin2 + vmax2) / 2
    vr1 = (vmax1 - vmin1) / 2 * extendv
    vr2 = (vmax2 - vmin2) / 2 * extendv
    v1 = torch.arange(vc1 - vr1, vc1 + vr1, dv, device=device, dtype=dtype)
    v2 = torch.arange(vc2 - vr2, vc2 + vr2, dv, device=device, dtype=dtype)
    xv = [xJ[0], v1, v2]
    XV = torch.stack(torch.meshgrid(*xv, indexing='ij'), -1)
    XV2d = XV[..., 1:]
    # Random field for the example deformation, created on the same device as the grid.
    v2d = torch.randn(XV2d.shape, dtype=XV2d.dtype, device=XV2d.device)

    # get highpass and lowpass operators for 2d reg
    L = L_from_xv_a_p(xv, a, p)
    LL = L**2
    K = 1.0 / LL
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(L))
    save_fig('fig4_L.png')

    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(K))
    save_fig('fig5_K.png')

    # Smooth the random field by multiplying with K in the Fourier domain
    v2d = torch.fft.ifftn(torch.fft.fftn(v2d, dim=(1, 2)) * K[..., None], dim=(1, 2)).real
    v3d = v2DToV3D(v2d)
    v3d /= torch.std(v3d)
    v3d *= 20
    v = v2d  # for later
    phi = exp(xv, v3d)
    fig, ax = plt.subplots()
    ax.contour(_to_numpy(xv[2]), _to_numpy(xv[1]), _to_numpy(phi[phi.shape[0] // 2, ..., 1]))
    ax.contour(_to_numpy(xv[2]), _to_numpy(xv[1]), _to_numpy(phi[phi.shape[0] // 2, ..., 2]))
    ax.set_title('Example deformation')
    save_fig('fig6_exdef.png')

    # initial guess
    I = (torch.sum(J * W, 1, keepdims=True) / (1e-6 + torch.sum(W, 0, keepdims=True))).repeat(1, J.shape[1], 1, 1)
    xI = [x.clone() for x in xJ]
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(I[:, I.shape[1] // 2].permute(1, 2, 0)))
    save_fig('fig7.png')

    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(I[:, :, I.shape[2] // 2].permute(1, 2, 0)), aspect='auto')
    save_fig('fig8.png')

    # transform an image with phi
    phiI = interp(xI, I, phi)
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(phiI[:, AJ.shape[1] // 2].permute(1, 2, 0)))
    save_fig('fig9.png')

    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(phiI[:, :, phiI.shape[2] // 2].permute(1, 2, 0)), aspect='auto')
    save_fig('fig10.png')

    # Compute vnew (preliminary registration with fixed, hard-coded settings)
    A_temp = torch.eye(3)[None].repeat(J.shape[1], 1, 1)
    v_temp = torch.zeros_like(v)
    Anew, vnew, Eregistration, Ereg = weighted_see_registration(xI, I, xJ, J, W, xv, v_temp, A_temp, a, p, sigmaM=1.0, sigmaR=1e5, niter=10, epT=1e-2, epL=1e-6, epv=1e1, draw=5)

    # Get the jacobian weights
    Wdetjac = detjac(xv, vnew)
    fig, ax = plt.subplots()
    mappable = ax.imshow(_to_numpy(Wdetjac[Wdetjac.shape[0] // 2]))
    plt.colorbar(mappable)
    save_fig('fig11.png')

    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(W[W.shape[0] // 2]), interpolation='none')
    save_fig('fig12.png')

    RphiI = transform_image(xI, I, xv, vnew, Anew, xJ)
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(RphiI[:, :, RphiI.shape[2] // 2].permute(1, 2, 0)), aspect='auto', interpolation='none')
    save_fig('fig13.png')

    _, WR = robust_loss(RphiI, xJ, J, W, c, return_weights=True)
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(WR[WR.shape[0] // 2]))
    save_fig('fig14.png')

    # get phiiRiJ
    phiiRiJ = inverse_transform_image(xJ, J, xv, vnew, Anew, xI, padding_mode='border')
    phiiRiW = inverse_transform_image(xJ, W[None] * WR, xv, vnew, Anew, xI, padding_mode='zeros', mode='nearest')[0]
    XI = torch.stack(torch.meshgrid(xI, indexing='ij'), -1)
    Wdetjacs = interp(xv, Wdetjac[None], XI)[0]
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(phiiRiJ[:, I.shape[1] // 2].permute(1, 2, 0)))
    save_fig('fig15.png')

    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(phiiRiW[I.shape[1] // 2]))
    save_fig('fig16.png')

    Inew = I.clone()
    Inew, Eat, ERat = atlas_from_aligned_slices_and_weights(xI, Inew * 0, dtype, device, phiiRiJ, phiiRiW * Wdetjacs, asquare=2.0**2, niter=2, draw=True, anisotropy_factor=1.0)
    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(Inew[:, I.shape[1] // 2].permute(1, 2, 0)))
    save_fig('fig17.png')

    fig, ax = plt.subplots()
    ax.imshow(_to_numpy(Inew[:, :, I.shape[2] // 2].permute(1, 2, 0)), aspect='auto')
    save_fig('fig18.png')

    # let's test the FWHM
    fwhm = atlas_from_aligned_slices_and_weights(xI, Inew * 0, dtype, device, phiiRiJ, phiiRiW * Wdetjacs, asquare=3.5**2, niter=2, draw=True, return_fwhm=True, anisotropy_factor=0.3**2)

    # ====================================
    # ===== (2) Perform registration =====
    # ====================================
    fig_at, ax_at = plt.subplots(2, 3)
    ax_at = ax_at.ravel()
    hfig_at = display(fig_at, display_id=True)
    fig_at_estimate = plt.figure()
    hfig_at_estimate = display(fig_at_estimate, display_id=True)
    fig_reg = plt.figure()
    hfig_reg = display(fig_reg, display_id=True)
    fig_E, ax_E = plt.subplots(1, 1)
    if isinstance(ax_E, np.ndarray):
        ax_E = ax_E.ravel()
    else:
        ax_E = [ax_E]
    hfig_E = display(fig_E, display_id=True)

    # we want xI bigger than xJ so there are no boundary issues
    bigger = 20
    x2dI = [torch.arange(n + bigger, dtype=dtype, device=device) * down - (n + bigger - 1) * down / 2 for n in nJ[1:]]
    xI = [torch.arange(nJ[0], dtype=dtype, device=device) - (nJ[0] - 1) / 2, x2dI[0], x2dI[1]]
    XI = torch.stack(torch.meshgrid(xI, indexing='ij'), -1)

    # this is the loss we want to report, not WSEE loss
    Esave = []
    v = torch.zeros_like(v)
    # NOTE(review): A is created with torch's default dtype/device, as in the
    # original; presumably the helpers convert it - confirm before running off-CPU.
    A = torch.eye(3)
    A = A[None].repeat(nJ[0], 1, 1)

    # initialize with mean (per-channel weighted mean of the data)
    I = torch.zeros((J.shape[0], XI.shape[0], XI.shape[1], XI.shape[2]), dtype=dtype, device=device) + (torch.sum(J * W, dim=(1, 2, 3)) / torch.sum(W, dim=(0, 1, 2)))[..., None, None, None]

    # first get the loss and the weights, using current guesses
    RphiI = transform_image(xI, I, xv, v, A, xJ)
    rloss, W_robust_loss = robust_loss(RphiI, xJ, J, W, c, return_weights=True)

    for it_big_loop in range(niter_big_loop):
        if it_big_loop == 0:
            asquare = 4.0**2 * asquare0
        elif it_big_loop == 20:
            asquare = 2.0**2 * asquare0
        elif it_big_loop == 40:
            asquare = 1.0**2 * asquare0
        # NOTE(review): this unconditional assignment overrides the schedule above
        # (and the -asquare argument). Kept as found - confirm it is intended.
        asquare = asquare0

        # now register, but wait until I've estimated a reasonable atlas
        if it_big_loop > 0:
            A, v, Eregistration, Ereg = weighted_see_registration(xI, I, xJ, J, W * W_robust_loss, xv, v, A, a_reg, p, sigmaM, sigmaR, niter_reg, epT, epL, epv, draw=5, fig=fig_reg, hfig=hfig_reg)
        else:
            Ereg = 0.0
        # and we want to add Ereg to our loss

        # now get jacobians
        Wdetjac = detjac(xv, v)

        # now update atlas
        phiiRiJ = inverse_transform_image(xJ, J, xv, v, A, xI, padding_mode='border', mode='nearest')
        phiiRiW = inverse_transform_image(xJ, W[None] * W_robust_loss, xv, v, A, xI, padding_mode='zeros', mode='nearest')[0]
        Wdetjacs = interp(xv, Wdetjac[None], XI)[0]
        I, Eat, ERat = atlas_from_aligned_slices_and_weights(xI, I, dtype, device, phiiRiJ, phiiRiW * Wdetjacs, asquare, niter=niter_atlas, fig=fig_at_estimate, hfig=hfig_at_estimate, draw=True, anisotropy_factor=anisotropy_factor)
        # and we want to add ERat to the loss

        # get the loss and the weights, using current guesses
        RphiI = transform_image(xI, I, xv, v, A, xJ)
        rloss, W_robust_loss = robust_loss(RphiI, xJ, J, W, c, return_weights=True)
        # this is the loss we want to report, it's the loss with the current parameters

        ax_at[0].cla()
        ax_at[0].imshow(_to_numpy(I[:, I.shape[1] // 2].permute(1, 2, 0)))
        ax_at[1].cla()
        ax_at[1].imshow(_to_numpy(I[:, :, I.shape[2] // 2].permute(1, 2, 0)), aspect='auto', interpolation='none')
        ax_at[2].cla()
        ax_at[2].imshow(_to_numpy(I[:, :, :, I.shape[3] // 2].permute(1, 2, 0)), aspect='auto', interpolation='none')
        Wshow = (phiiRiW * Wdetjacs)
        ax_at[3].cla()
        ax_at[3].imshow(_to_numpy(Wshow[I.shape[1] // 2]))
        ax_at[4].cla()
        ax_at[4].imshow(_to_numpy(Wshow[:, I.shape[2] // 2]), aspect='auto', interpolation='none')
        ax_at[5].cla()
        ax_at[5].imshow(_to_numpy(Wshow[:, :, I.shape[3] // 2]), aspect='auto', interpolation='none')

        # is this the right error? yes I think so
        Esave.append([rloss.item() + Ereg + ERat, rloss.item(), Ereg, ERat])
        ax_E[0].cla()
        ax_E[0].plot(Esave)
        ax_E[0].legend(['total', 'robust matching', 'registration reg', 'atlas reg'])

        if saveAllFigs:
            fig_at_estimate.savefig(os.path.join(outdir, f'atlas_{it_big_loop:06d}.png'))

    # NOTE(review): outputs are written once, after the loop; the flattened source
    # does not show whether they were originally written every iteration.
    # Tensors go through _to_numpy so that saving does not fail for tensors that
    # require grad or are not on the CPU.
    np.savez(os.path.join(outdir, 'A.npz'), data=_to_numpy(A))
    np.savez(os.path.join(outdir, 'v.npz'), data=_to_numpy(v))
    np.savez(os.path.join(outdir, 'Esave.npz'), data=Esave)
    np.savez(os.path.join(outdir, 'RphiI.npz'), data=_to_numpy(RphiI))
    np.savez(os.path.join(outdir, 'phiiRiJ.npz'), data=_to_numpy(phiiRiJ))
    np.savez(os.path.join(outdir, 'Wshow.npz'), data=_to_numpy(Wshow))
    np.savez(os.path.join(outdir, 'W_robust_loss.npz'), data=_to_numpy(W_robust_loss))
# Run the alignment pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()