# Source code for spine_reg_pipeline

import numpy as np
import matplotlib.pyplot as plt
import os
import sys

import argparse
from spine_reg import * # TODO: Add path argument in final version
import gc

def parse_torch_dtype(name):
    """Convert a command-line value into a ``torch.dtype``.

    Parameters:
    ===========
    name : str or torch.dtype
        Either a dtype name such as ``'float32'`` or ``'torch.float32'``, or an
        existing ``torch.dtype`` (returned unchanged).

    Returns:
    ========
    torch.dtype
        The dtype object that ``name`` refers to.

    Raises:
    =======
    argparse.ArgumentTypeError
        If ``name`` does not name a ``torch.dtype``.
    """
    if isinstance(name, torch.dtype):
        return name
    attr = str(name).strip()
    if attr.startswith('torch.'):
        attr = attr[len('torch.'):]
    dtype = getattr(torch, attr, None)
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f'{name!r} is not a valid torch dtype')
    return dtype


def main():
    """Register an interpolated spine atlas to a spine reflection image.

    The script loads the atlas, its labels, the target image and a set of target
    points, optimises a slice-wise rigid/squish model, an affine map and a
    velocity field by gradient descent, and writes the resulting parameters and
    the sampled forward/inverse transforms to ``outdir``.

    Command Line Arguments:
    =======================
    fname_I : str
        The path to the interpolated atlas .npz file to be used for registration
    fname_L : str
        The path to the interpolated atlas labels .npz file to be used for registration
    fname_J : str
        The path to the spine reflection .npz file to be registered
    fname_pointsJ : str
        The path to an .swc file used for more precise registration of J
    outdir : str
        The directory where all output files should be saved
    e_path : str
        The location of the custom Python library 'emlddmm', which can be cloned
        from GitHub at https://github.com/twardlab/emlddmm
    -device : str
        Default - cpu; The device where registration computations should occur
    -dtype : torch.dtype
        Default - torch.float32; The data type used for the voxel values in image
        registration. Given on the command line as e.g. ``float32`` or ``torch.float32``
    -niter : int
        Default - 5000; The number of iterations to use for image registration
    -down : list of int
        Default - [8,8,8]; The factor used to downsample along each axis of I, respectively
    -bs, --blocksize : int
        Default - 50; Accepted for interface compatibility. The registration loop
        currently overrides it with a fixed schedule (0, then 50, then 32)
    -v, --verbose : bool
        Default - False; If present, print out several progress messages
    -saveAllFigs : bool
        Default - False; If present, save every figure generated before, during, and after registration
    -saveIntermediateFigs : bool
        Default - False; If present, save intermediate figures at each iteration of the registration process
    -saveFig0 ... -saveFig16 : bool
        Default - False; If present, save the corresponding individual figure
        (see ``--help`` for a description of each one)

    Raises:
    =======
    ValueError
        If any of the 4 input files does not have the correct file extension, or
        if the .swc file contains no points
    Exception
        If a negative Jacobian arises during the registration loop
    """
    # ===================================
    # ===== Command line definition =====
    # ===================================
    parser = argparse.ArgumentParser()
    parser.add_argument('fname_I', type=str, help='The path to the interpolated atlas .npz file')
    parser.add_argument('fname_L', type=str, help='The path to the interpolated atlas labels .npz file')
    parser.add_argument('fname_J', type=str, help='The path to the spine reflection .npz file')
    parser.add_argument('fname_pointsJ', type=str, help='The input .swc file used for more precise registration')
    parser.add_argument('outdir', type=str, help='The directory where all output files should be saved')
    parser.add_argument('e_path', type=str, help='The location of the custom Python library \'emlddmm\', which can be cloned from GitHub at https://github.com/twardlab/emlddmm')
    parser.add_argument('-device', type=str, default='cpu', help='Default - cpu; The device where registration computations should occur')
    # torch.dtype cannot be called with a string, so a dedicated parser is required.
    parser.add_argument('-dtype', type=parse_torch_dtype, default=torch.float32, help='Default - torch.float32; The data type used for image registration')
    parser.add_argument('-niter', type=int, default=5000, help='Default - 5000; The number of iterations to use for image registration')
    parser.add_argument('-down', type=int, nargs=3, default=[8, 8, 8], help='The factor used to downsample along each axis of I, respectively')
    parser.add_argument('-bs', '--blocksize', type=int, default=50, help='TODO: ...')
    parser.add_argument('-v', '--verbose', action='store_true', help='Default - False; If present, print out several progress messages throughout the registration process')
    parser.add_argument('-saveAllFigs', action='store_true', help='Default - False; If present, save every intermediate figure generated before, during, and after registration')
    parser.add_argument('-saveIntermediateFigs', action='store_true', help='Default - false; If present, save intermediate figures at each iteration of the registration process')

    target_view_help = 'Default - False; If present, save a standard view of the target data along all 3 cardinal axes'
    gradient_help = 'Default - False; If present, save a color gradient'
    fig_help = [
        'Default - False; If present, save a MIP of the interpolated atlas provided',
        'Default - False; If present, save a MIP of the interpolated atlas labels provided',
        'Default - False; If present, save a scatter plot of all the points labeled in the interpolated atlas porvided',
        target_view_help,
        'Default - False; If present, save a MIP of the target data along all 3 cardinal axes',
        target_view_help,
        target_view_help,
        'Default - False; If present, save a scatter plot of all points along the AP axis',
        'Default - False; If present, save a standard view of the qJU object along the AP axis',
        'Default - False; If present, save a standard view of the qJU object along the AP axis with points superimposed',
        'Default - False; If present, save a standard view of the downsampled target data along all 3 cardinal axes',
        'Default - False; If present, save a plot of the xv*xId object',
        'Default - False; If present, save a plot of the B object',
        gradient_help,
        gradient_help,
        gradient_help,
        gradient_help,
    ]
    for k, text in enumerate(fig_help):
        parser.add_argument(f'-saveFig{k}', action='store_true', help=text)

    args = parser.parse_args()
    fname_I = args.fname_I
    fname_L = args.fname_L
    fname_J = args.fname_J
    fname_pointsJ = args.fname_pointsJ
    outdir = args.outdir
    e_path = args.e_path
    device = args.device
    dtype = args.dtype
    niter = args.niter
    down = args.down
    verbose = args.verbose
    saveAllFigs = args.saveAllFigs
    saveIntermediateFigs = args.saveIntermediateFigs

    def want_fig(k):
        """Return True when figure number k was requested on the command line."""
        return saveAllFigs or getattr(args, f'saveFig{k}')

    def to_numpy(t):
        """Detach a tensor and copy it to a NumPy array on the CPU."""
        return t.detach().cpu().numpy()

    # =================================================
    # ===== Perform validity checks on parameters =====
    # =================================================
    # Checked before any side effect (directory creation, imports). The suffix is
    # tested with endswith so that names such as 'atlas.npz.bak' are rejected.
    for path, ext in ((fname_I, '.npz'), (fname_L, '.npz'), (fname_J, '.npz'), (fname_pointsJ, '.swc')):
        if not path.endswith(ext):
            raise ValueError(f'{path} should be a {ext} file')

    sys.path.append(e_path)
    import emlddmm

    # Create outdir if it doesn't already exist
    os.makedirs(outdir, exist_ok=True)

    # ===================================
    # ===== Load interpolated atlas =====
    # ===================================
    # NOTE: allow_pickle=True executes pickled content; only load trusted files.
    data_I = np.load(fname_I, allow_pickle=True)
    I = data_I['I']
    xI = data_I['xI']
    I = I / I.max()  # Normalize atlas values
    lowert = 0.15  # This is an atlas-dependent parameter - locally found at /nafs/dtward/spine_work/interpolated_atlas.npz
    # Invert the contrast, sending values below the threshold to zero, then rescale to [0, 1]
    I[I < lowert] = 1.0
    I = (1 - I)
    I = I - lowert
    I[I < 0] = 0
    I = I / I.max()

    if want_fig(0):
        fig, axs = draw(I, xI, function=getslice)
        fig.savefig(os.path.join(outdir, 'fig0_interp_atlas.png'))

    if verbose:
        print('Successfully loaded atlas file . . .')

    # ==========================================
    # ===== Load interpolated atlas labels =====
    # ==========================================
    data_L = np.load(fname_L, allow_pickle=True)
    L = data_L['I'] % 256

    if want_fig(1):
        fig, axs = draw((L % 256) == 16, xI)
        fig.savefig(os.path.join(outdir, 'fig1_interp_atlas_labels.png'))

    # Atlas points: grid locations carrying label 16, randomly subsampled to at most nqU points
    qIU = np.stack(np.meshgrid(*xI, indexing='ij'), -1)[(L[0] % 256) == 16]
    nqU = 1000  # TODO: Decide if this should be a passable argument
    qIU = qIU[np.random.permutation(qIU.shape[0])[:nqU]]

    if want_fig(2):
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(*qIU.T)
        fig.savefig(os.path.join(outdir, 'fig2_scattered_labels.png'))

    # Per-axis mean of the smallest non-zero squared distance between points
    SigmaQIU = []
    for j in range(3):
        d2 = (qIU[:, None, j] - qIU[None, :, j])**2
        SigmaQIU.append(np.mean([np.min(row[row > 0]) for row in d2]))
    SigmaQIU = np.array(SigmaQIU)

    if verbose:
        print('Successfully loaded atlas labels file . . .')

    # =====================================================================
    # ===== Load the target image data to be registered + axis values =====
    # =====================================================================
    try:
        data_J = np.load(fname_J, allow_pickle=True)
        J = data_J['I']
        xJ = data_J['xI']
    except Exception:
        # Fall back to the data stored as separate .npy files next to the .npz path
        J = np.load(fname_J.replace('.npz', '_I.npy'))
        xJ = [np.load(fname_J.replace('.npz', f'_xI{k}.npy')) for k in range(3)]

    # If necessary, add a 4th dimension to J
    if J.ndim == 3:
        J = J[None]

    # Readjust coordinate axes to be centered about 0
    xJ = [x - np.mean(x) for x in xJ]

    # Normalize J
    low = 0  # TODO: ADD AS ARG
    high = 1500  # TODO: ADD AS ARG
    J = J.clip(low, high)
    WJ = 1.0 - (J == high)  # zero wherever the image saturated
    J = J - low
    J = J / (high - low)
    J = J * WJ
    WJ = WJ * 0 + 1  # the weights are reset to one everywhere after masking J

    # Crop applied to the target before downsampling (channel axis untouched)
    sl = (slice(None), slice(24, -50, None), slice(50, -110, None), slice(200, -310, None))

    if want_fig(3):
        fig, axs = draw((J * WJ)[sl], [x[s] for s, x in zip(sl[1:], xJ)], function=getslice)
        fig.savefig(os.path.join(outdir, 'fig3_target_slice0.png'))

    if want_fig(4):
        fig, axs = draw((J * WJ)[sl], [x[s] for s, x in zip(sl[1:], xJ)])
        fig.savefig(os.path.join(outdir, 'fig4_target_MIP.png'))

    if want_fig(5):
        fig, axs = draw(J * WJ, xJ, function=getslice)
        fig.savefig(os.path.join(outdir, 'fig5_target_slice1.png'))

    if want_fig(6):
        fig, axs = draw(WJ, xJ, function=getslice)
        fig.savefig(os.path.join(outdir, 'fig6_target_slice2.png'))

    if verbose:
        print('Successfully loaded target image data and axes files . . .')

    # Initialize qJU, which is a set of unlabeled points associated with the target image
    qJU = []
    with open(fname_pointsJ) as f:
        for line in f:
            stripped = line.strip()
            # Skip blank lines (previously an IndexError) and comment lines
            if not stripped or stripped[0] == '#':
                continue
            qJU.append(np.array([float(s) for s in line.split()[2:5]]))
    if not qJU:
        raise ValueError(f'{fname_pointsJ} does not contain any points')
    qJU = np.stack(qJU)
    # Sample nqU points with replacement and jitter them
    qJU = qJU[np.random.randint(low=0, high=qJU.shape[0], size=nqU)]
    qJU = qJU + np.random.randn(*qJU.shape) * 2

    if want_fig(7):
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(*qJU.T)
        ax.set_aspect('equal')
        fig.savefig(os.path.join(outdir, 'fig7_qJU_scatter.png'))

    if want_fig(8):
        fig, ax = plt.subplots()
        ax.imshow(J.sum((0, 1)), extent=(xJ[2][0], xJ[2][-1], xJ[1][-1], xJ[1][0]))
        ax.plot(qJU[:, 2], qJU[:, 1])
        fig.savefig(os.path.join(outdir, 'fig8_qJU0.png'))

    if want_fig(9):
        fig, ax = plt.subplots()
        ax.imshow(J[0, :, J.shape[2] // 2], extent=(xJ[2][0], xJ[2][-1], xJ[0][-1], xJ[0][0]))
        ax.plot(qJU[:, 2], qJU[:, 0])
        fig.savefig(os.path.join(outdir, 'fig9_qJU1.png'))

    if verbose:
        print('Successfully loaded target image points file . . .')

    # =========================
    # ===== Preprocessing =====
    # =========================
    # Downsample the atlas image, I, and target image, J
    xId, Id = emlddmm.downsample_image_domain(xI, I, down)
    xJd, Jd, WJd = emlddmm.downsample_image_domain([x[s] for x, s in zip(xJ, sl[1:])], J[sl], down, W=WJ[0][sl[1:]])

    if want_fig(10):
        fig, axs = draw(Jd * WJd, xJd, function=getslice)
        fig.savefig(os.path.join(outdir, 'fig10_target_down.png'))

    # Convert input data from Numpy to torch format
    dd = {'device': device, 'dtype': dtype}
    Id = torch.tensor(Id, **dd)
    Jd = torch.tensor(Jd, **dd)
    xId = [torch.tensor(x, **dd) for x in xId]
    xJd = [torch.tensor(x, **dd) for x in xJd]
    WJd = torch.tensor(WJd, **dd)
    XId = torch.stack(torch.meshgrid(xId, indexing='ij'), -1)
    XJd = torch.stack(torch.meshgrid(xJd, indexing='ij'), -1)
    qIU = torch.tensor(qIU, **dd)
    qJU = torch.tensor(qJU, **dd)
    SigmaQIU = torch.tensor(SigmaQIU, **dd)

    # Initialize parameters for deformation portion of registration
    # TODO: Add nt, a, and p as arguments using the below values for default
    nt = 5
    a = 200.0
    p = 2.0
    expand = [1.02, 1.2, 1.2]
    dv = a * 0.5
    # Velocity-field grid: the atlas bounding box enlarged by `expand`, sampled every dv
    vminmax = [(torch.min(x), torch.max(x)) for x in xId]
    vr = torch.stack([(x[1] - x[0]) for x in vminmax])
    vc = torch.stack([(x[1] + x[0]) / 2 for x in vminmax])
    vminmax = vc + (vr * torch.tensor(expand, **dd)) * torch.tensor([-1, 1], **dd)[..., None] * 0.5
    xv = [torch.arange(vm[0], vm[1], dv) for vm in vminmax.T]
    XV = torch.stack(torch.meshgrid(xv, indexing='ij'), -1)
    nv = [len(x) for x in xv]
    # Frequency-domain smoothing kernel K = 1/LL used for the velocity regulariser
    fv = [torch.arange(n) / n / dv for n in XV.shape[:3]]
    FV = torch.stack(torch.meshgrid(fv, indexing='ij'), -1)
    LL = (1.0 + 2.0 * a**2 * torch.sum((1.0 - torch.cos(2.0 * np.pi * FV * dv)) / dv**2, -1))**(2.0 * p)
    K = 1.0 / LL

    # Optimised parameters: velocity field plus per-slice rotation, translation and squish
    v = torch.zeros(nt, *nv, 3, requires_grad=True)
    theta = torch.zeros(nv[0], requires_grad=True, **dd)
    T = torch.zeros((nv[0], 2), requires_grad=True, **dd)
    squish = torch.ones(nv[0], **dd) * (-0.4)  # exponentiate it (-0.1 was good), -1 makes it very short and fat in the coronal plane
    squish.requires_grad = True

    # Row-normalised Gaussian blur matrix taking values on xv[0] to values on xId[0]
    B = torch.exp(-(xv[0][None, :] - xId[0][:, None])**2 / 2 / 2000**2)
    B = B / torch.sum(B, 1, keepdims=True)
    squishb = B @ squish

    if want_fig(11):
        fig, ax = plt.subplots()
        ax.plot(xv[0], squish.detach())
        ax.plot(xId[0], squishb.detach())
        fig.savefig(os.path.join(outdir, 'fig11_xv_xId.png'))

    if want_fig(12):
        fig, ax = plt.subplots()
        ax.imshow(B)
        fig.savefig(os.path.join(outdir, 'fig12_B.png'))

    # Initialize the affine with an axis permutation
    A = torch.eye(4)[[1, 2, 0, 3]]
    A.requires_grad = True
    stretch = torch.tensor([0.2, -0.2, -0.75], requires_grad=True)  # the third number, in the coronal plane, makes it grow left right if it is positive

    # Define a metric for affine, since coordinate system is centered at 0, my off diagonal terms will be 0 in the standard push forward approach
    # for two basis matrices, we need to evaluate:
    # gij = int (Eix)^T Ejx dx = int trace [ x^T Ei^T Ej x ] dx = int trace [xx^T Ei^TEj]dx = trace [ int xx^T dx Ei^TEj ]
    XX = torch.sum(XId[..., None] * XId[..., None, :], (0, 1, 2))
    O = torch.sum(torch.ones_like(XId[..., 0]))
    XXO = torch.diag(torch.concatenate((torch.diag(XX), O[None])))
    basis = [
        (torch.arange(4, **dd) == i)[..., None] * (torch.arange(4, **dd) == j)[..., None, :] * 1.0
        for i in range(3) for j in range(4)
    ]
    g = torch.zeros((12, 12), **dd)
    for r, Er in enumerate(basis):
        for c, Ec in enumerate(basis):
            g[r, c] = torch.trace(XXO @ (Er.T @ Ec))
    gi = torch.linalg.inv(g)

    def inverse_map():
        """Map the downsampled target grid XJd into atlas space with the current parameters.

        Returns the mapped coordinates together with the blurred per-slice
        parameters (thetab, Tb, squishb) that were used to build them.
        """
        thetab = B @ theta
        Tb = B @ T
        squishb = B @ squish

        # First, the inverse affine (including the stretch)
        Ai = torch.diag((-stretch).exp()) @ torch.linalg.inv(A)[:3]
        Xs = (Ai[:3, :3] @ XJd[..., None])[..., 0] + Ai[:3, -1]

        # Second, the per-slice translation, squish and rotation, all sampled at once
        tosample = torch.concatenate((Tb.T, squishb[None], thetab[None]))
        out = interp1d(xId, tosample, Xs, dd)
        Ts = out[0:2].permute(1, 2, 3, 0)
        squishs = out[2]
        thetas = out[3]
        Tcat = torch.concatenate((torch.zeros_like(Ts[..., 0, None]), Ts), -1)
        one = torch.ones_like(thetas)
        zero = torch.zeros_like(thetas)
        squishmat = torch.diag_embed(torch.stack([torch.ones_like(squishs), (-squishs).exp(), (squishs).exp()], -1))
        rotmat = torch.stack([
            torch.stack([one, zero, zero], -1),
            torch.stack([zero, torch.cos(thetas), torch.sin(thetas)], -1),
            torch.stack([zero, -torch.sin(thetas), torch.cos(thetas)], -1),
        ], -2)
        Xs = Xs - Tcat
        Xs = (rotmat @ squishmat @ Xs[..., None])[..., 0]

        # Third, the diffeomorphism
        phii = phii_from_v(xv, v)
        Xs = interp(xv, (phii - XV).permute(-1, 0, 1, 2), Xs).permute(1, 2, 3, 0) + Xs
        return Xs, thetab, Tb, squishb

    # ========================
    # ===== Registration =====
    # ========================
    figE, axE = plt.subplots(3, 4)
    axE = axE.ravel()
    figI = plt.figure()
    figErr = plt.figure()
    figJ = plt.figure()
    draw(Jd.detach(), xJd, function=getslice, fig=figJ, aspect='auto')
    figJ.canvas.draw()
    figQ = plt.figure()
    axQ = figQ.add_subplot(projection='3d')
    axQ.cla()
    axQ.scatter(*qIU.T, label='qIU', alpha=0.1)
    axQ.legend()

    # Initialize output figures and data structures for storing output data
    nrow = 6
    ncol = 4
    figS, axS = plt.subplots(nrow, ncol)
    figS.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0)
    axS = axS.ravel()
    nshow = nrow * ncol
    slices = np.round(np.linspace(0, Jd.shape[-1] - 1, nshow + 2)).astype(int)

    Esave = []
    Tsave = []  # the max
    squishsave = []
    thetasave = []
    vsave = []
    ALsave = []
    ATsave = []
    update_v = False
    vstart = -1

    # Initialize various step sizes
    # TODO: All of these aside from measureMatchingSigma should be included
    epv = 5e4
    eptheta = 5e-4
    epSquish = 2e-4
    epT = 2e2
    epStretch = 0  # the stretch is held fixed
    epA = 2e6
    measureMatchingSigma = SigmaQIU * 5**2
    sigmaM = (1e5)**0.5
    sigmaR = 2e5
    sigmaQU = 1e-1 * 50

    # Loop-invariant quantities
    voxel_volume_J = torch.prod(torch.stack([x[1] - x[0] for x in xJd]))
    wIU = torch.ones_like(qIU[..., 0])
    EQU0 = measure_matching_dot(qIU, wIU, qIU, wIU, measureMatchingSigma)  # only compute once

    Xs = None
    # Begin the registration loop over 'niter' iterations
    for it in range(0, niter):
        if it > vstart:
            update_v = True

        # NOTE(review): this schedule overrides the --blocksize command line value.
        if it < 500:
            blocksize = 0
        elif it < 1000:
            blocksize = 50
        else:
            blocksize = 32

        # Drop the previous iteration's coordinates before rebuilding them
        Xs = None
        gc.collect()

        Xs, thetab, Tb, squishb = inverse_map()

        phiI = interp(xId, Id, Xs)
        phiiQJU = interp(xJd, Xs.permute(-1, 0, 1, 2), qJU[None, None])[:, 0, 0].T
        dphii = torch.ones_like(phiiQJU[..., 0])
        # The original compared the boolean torch.any(dphii) with 0, which never fired.
        if torch.any(dphii <= 0):
            raise Exception('Negative jacobian')

        # Contrast: weighted linear fit of the deformed atlas to the target,
        # globally (blocksize == 0) or independently per block
        if blocksize > 0:
            Jdpp = toblocks(Jd, blocksize)
            WJdpp = toblocks(WJd[None], blocksize)
            phiIpp = toblocks(phiI, blocksize)

        with torch.no_grad():
            if blocksize == 0:
                wsum = torch.sum(WJd)
                muI = torch.sum(phiI * WJd) / wsum
                muJ = torch.sum(Jd * WJd) / wsum
                varI = torch.sum((phiI - muI)**2 * WJd) / wsum
                covIJ = torch.sum((phiI - muI) * (Jd - muJ) * WJd) / wsum
            else:
                dims = (-3, -2, -1)
                wsum = torch.sum(WJdpp, dims, keepdims=True)
                muI = torch.sum(phiIpp * WJdpp, dims, keepdims=True) / wsum
                muJ = torch.sum(Jdpp * WJdpp, dims, keepdims=True) / wsum
                varI = torch.sum((phiIpp - muI)**2 * WJdpp, dims, keepdims=True) / wsum
                covIJ = torch.sum((phiIpp - muI) * (Jdpp - muJ) * WJdpp, dims, keepdims=True) / wsum

        if blocksize > 0:
            fphiIpp = (phiIpp - muI) / (varI + 1e-6) * covIJ + muJ
            fphiI = fromblocks(fphiIpp, Jd.shape)
        else:
            fphiI = (phiI - muI) / (varI + 1e-6) * covIJ + muJ

        # Objective: image matching + velocity regularisation + point (measure) matching
        err = (fphiI - Jd)
        EM = (err**2 * WJd).sum() / 2.0 * voxel_volume_J / sigmaM**2
        ER = torch.sum(torch.sum(torch.abs(torch.fft.fftn(v, dim=(-4, -3, -2)))**2, (0, -1)) * K) / 2.0 / nt / v[0, :, :, :, 0].numel() / sigmaR**2 * voxel_volume_J
        EQU = (EQU0
               + -2 * measure_matching_dot(qIU, wIU, phiiQJU, dphii, measureMatchingSigma)
               + measure_matching_dot(phiiQJU, dphii, phiiQJU, dphii, measureMatchingSigma)) / 2.0 / sigmaQU**2
        E = EM + ER + EQU
        E.backward()
        Esave.append([E.item(), EM.item(), ER.item(), EQU.item()])

        # Draw the error figures
        draw(err.detach() * WJd, xJd, function=getslice, fig=figErr, aspect='auto')
        figErr.canvas.draw()
        draw(fphiI.detach(), xJd, function=getslice, fig=figI, aspect='auto')
        figI.canvas.draw()

        # Update parameters + set gradients to 0
        theta.data = theta.data - theta.grad * eptheta
        theta.grad.zero_()
        T.data = T.data - T.grad * epT
        T.grad.zero_()
        squish.data = squish.data - squish.grad * epSquish
        squish.grad.zero_()
        stretch.data = stretch.data - stretch.grad * epStretch
        stretch.grad.zero_()
        if update_v:
            # Smooth the gradient with K in the Fourier domain before stepping
            v.data = v.data - torch.fft.ifftn((torch.fft.fftn(v.grad, dim=(1, 2, 3)) * K[..., None]), dim=(1, 2, 3)).real * epv
        v.grad.zero_()
        # Affine step preconditioned with the inverse metric gi
        Agrad = (gi @ A.grad[:3].ravel()).reshape(3, 4)
        A.data[:3] = A.data[:3] - epA * Agrad
        A.grad.zero_()

        ALsave.append(A[:3, :3].clone().detach().ravel().numpy())
        ATsave.append(A[:3, -1].clone().detach().ravel().numpy())

        # Generate plots showing all relevant registration variables
        thetasave.append(thetab.clone().detach().abs().max() * 180 / np.pi)
        Tsave.append(Tb.clone().detach().abs().max())
        squishsave.append(squishb.clone().detach().abs().max())
        vsave.append(v.detach().abs().max())

        axE[0].cla()
        axE[0].plot(Esave)
        panels = (
            (Tsave, "max T"),
            (thetasave, "max theta"),
            (squishsave, "max squish"),
            (vsave, "max v"),
            (ALsave, "AL"),
            (ATsave, "AT"),
            (Tb.detach(), "Tb"),
            (thetab.detach(), "thetab"),
            (squishb.detach(), "squishb"),
        )
        for axis, (series, title) in zip(axE[1:], panels):
            axis.cla()
            axis.plot(series)
            axis.set_title(title)
        figE.canvas.draw()

        axQ.cla()
        axQ.scatter(*qIU.T, label='qIU', alpha=0.1)
        axQ.scatter(*phiiQJU.detach().cpu().T, label='phiiQJU', alpha=0.1)
        axQ.legend()
        figQ.canvas.draw()

        for i in range(nshow):
            Jshow = (Jd[0, :, :, slices[i]] * WJd[:, :, slices[i]]).numpy()
            Jshow = Jshow / np.max(Jshow)
            # NOTE(review): Jshow is already normalised here, so this division is by
            # its new maximum - confirm whether the pre-normalisation maximum was intended.
            Ishow = fphiI[0, :, :, slices[i]].detach().numpy() / np.max(Jshow)
            axS[i].cla()  # avoid stacking a new image on the axes at every iteration
            axS[i].imshow(np.stack((Jshow, Ishow, Jshow), -1))
            axS[i].axis('off')
        figS.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0)

        if saveAllFigs or saveIntermediateFigs:
            figQ.savefig(os.path.join(outdir, f'out_Q_it_{it:06d}.png'))
            figErr.savefig(os.path.join(outdir, f'out_err_it_{it:06d}.png'))
            figS.savefig(os.path.join(outdir, f'out_S_it_{it:06d}.png'))

    # =====================================
    # ===== Save all relevant outputs =====
    # =====================================
    # First, save the final registered versions of all relevant parameters
    np.savez(os.path.join(outdir, 'saved_parameters.npz'),
             theta=to_numpy(theta),
             T=to_numpy(T),
             squish=to_numpy(squish),
             stretch=to_numpy(stretch),
             B=to_numpy(B),
             A=to_numpy(A),
             v=to_numpy(v),
             xv=np.array([to_numpy(x) for x in xv], dtype=object),
             xId=np.array([to_numpy(x) for x in xId], dtype=object),
             xJd=np.array([to_numpy(x) for x in xJd], dtype=object))

    # Second, save the inverse transform (target grid -> atlas coordinates)
    with torch.no_grad():
        Xs, thetab, Tb, squishb = inverse_map()

        np.savez(os.path.join(outdir, 'inverse_transform.npz'),
                 phii=to_numpy(Xs),
                 x=np.array([to_numpy(x) for x in xJd], dtype=object))

        # Generate + save a high res version of the inverse. A least-squares affine
        # fit is removed before interpolation and added back afterwards.
        XJd_ = torch.concatenate((XJd, torch.ones_like(XJd[..., 0, None])), -1)
        Xs_ = torch.concatenate((Xs.detach(), torch.ones_like(XJd[..., 0, None])), -1)
        fit = torch.linalg.solve(XJd_.reshape(-1, 4).T @ XJd_.reshape(-1, 4),
                                 XJd_.reshape(-1, 4).T @ Xs_.detach().reshape(-1, 4)).T
        FIT = (fit[:3, :3] @ XJd_[..., :3, None])[..., 0] + fit[:3, -1]

        fig, ax = plt.subplots(2, 1)
        atlas_origin = torch.tensor([xI[0][0], xI[1][0], xI[2][0]], **dd)
        atlas_spacing = torch.tensor([xI[0][1] - xI[0][0], xI[1][1] - xI[1][0], xI[2][1] - xI[2][0]], **dd)
        OUT = []
        for i in range(0, J.shape[1], 1):
            # get the location
            thisx = [[xJ[0][i]], xJ[1], xJ[2]]
            thisX = np.stack(np.meshgrid(*thisx, indexing='ij'), -1)
            # interpolate Xs, so that now when we fill with zeros, the result will be appropriate
            thisXfit = (fit[:3, :3] @ thisX[..., None])[..., 0] + fit[:3, -1]
            out = interp(xJd, (Xs - FIT).clone().detach().permute(-1, 0, 1, 2), torch.tensor(thisX, **dd)).permute(1, 2, 3, 0) + thisXfit.float()
            test = interp(xId, Id, out)
            ax[0].cla()
            ax[0].imshow((out[0] - out.min()) / (out.max() - out.min()))
            ax[0].set_title(i)
            ax[1].cla()
            ax[1].imshow(test.squeeze())
            ax[1].set_title(i)
            # scale: convert physical atlas coordinates to atlas voxel indices
            out = out - atlas_origin
            out = out / atlas_spacing
            OUT.append(out)
        OUT = torch.concatenate(OUT)
        OUT = OUT.numpy()

    # Write it out
    np.save(os.path.join(outdir, 'interpolated_atlas_to_spine_reflection_v04.npy'), OUT)

    if want_fig(13):
        fig.savefig(os.path.join(outdir, 'fig13_interp_atlas_to_spine.png'))

    # Third, save the forward transform (atlas grid -> target coordinates)
    with torch.no_grad():
        thetab = B @ theta
        Tb = B @ T
        squishb = B @ squish

        # forward phi, applied to the atlas grid
        phi = phii_from_v(xv, -v.flip(0))
        Xs = XId.clone()
        Xs = interp(xv, (phi - XV).permute(-1, 0, 1, 2), Xs).permute(1, 2, 3, 0) + Xs

        # then we can compose the other transformations
        # since these operations do not change the z coordinate
        # I can interpolate them all at once
        tosample = torch.concatenate((Tb.T, squishb[None], thetab[None]))
        out = interp1d(xId, tosample, Xs, dd)
        Ts = out[0:2].permute(1, 2, 3, 0)
        squishs = out[2]
        thetas = out[3]
        Tcat = torch.concatenate((torch.zeros_like(Ts[..., 0, None]), Ts), -1)  # I'm leaving the sign here the same
        one = torch.ones_like(thetas)
        zero = torch.zeros_like(thetas)
        # NOTE(review): the inverse uses diag(1, exp(-s), exp(s)), whose matrix inverse is
        # diag(1, exp(s), exp(-s)); the third entry below is left as originally written - confirm.
        squishmat = torch.diag_embed(torch.stack([torch.ones_like(squishs), (squishs).exp(), (squishs).exp()], -1))  # note I deleted the minus sign as compared to above
        rotmat = torch.stack([
            torch.stack([one, zero, zero], -1),
            torch.stack([zero, torch.cos(thetas), -torch.sin(thetas)], -1),  # note I moved the minus sign as compared to above
            torch.stack([zero, torch.sin(thetas), torch.cos(thetas)], -1),
        ], -2)
        # before the order was translate, squish, rotate
        # so now we do, rotate, squish translate
        Xs = (squishmat @ rotmat @ Xs[..., None])[..., 0]
        Xs = Xs + Tcat

        # last the affine, and the stretching is part of it. The inverse applies
        # diag(exp(-stretch)) @ inv(A), so the forward map is A_lin @ diag(exp(stretch)) @ x + A_t.
        # (The original added the translation to the 3x3 matrix by broadcasting and then
        # used that matrix's last column as the translation.)
        A_lin = A[:3, :3] @ torch.diag(stretch.exp())
        Xs = (A_lin @ Xs[..., None])[..., 0] + A[:3, -1]

        np.savez(os.path.join(outdir, 'forward_transform.npz'),
                 phi=to_numpy(Xs),
                 x=np.array([to_numpy(x) for x in xId], dtype=object))

        # Now generate a high res version of the forward transform
        # NOTE(review): unlike the inverse, this array is never written to disk - confirm intent.
        XId_ = torch.concatenate((XId, torch.ones_like(XId[..., 0, None])), -1)
        Xs_ = torch.concatenate((Xs.detach(), torch.ones_like(XId[..., 0, None])), -1)
        fit = torch.linalg.solve(XId_.reshape(-1, 4).T @ XId_.reshape(-1, 4),
                                 XId_.reshape(-1, 4).T @ Xs_.detach().reshape(-1, 4)).T
        FIT = (fit[:3, :3] @ XId_[..., :3, None])[..., 0] + fit[:3, -1]

        fig, ax = plt.subplots(2, 1)
        target_origin = torch.tensor([xJ[0][0], xJ[1][0], xJ[2][0]], **dd)
        target_spacing = torch.tensor([xJ[0][1] - xJ[0][0], xJ[1][1] - xJ[1][0], xJ[2][1] - xJ[2][0]], **dd)
        OUT = []
        for i in range(0, I.shape[1], 1):
            # get the location
            thisx = [[xI[0][i]], xI[1], xI[2]]
            thisX = np.stack(np.meshgrid(*thisx, indexing='ij'), -1)
            # interpolate Xs
            thisXfit = (fit[:3, :3] @ thisX[..., None])[..., 0] + fit[:3, -1]
            out = interp(xId, (Xs - FIT).clone().detach().permute(-1, 0, 1, 2), torch.tensor(thisX, **dd)).permute(1, 2, 3, 0) + thisXfit.float()
            bad = out == 0
            test = interp(xJd, Jd, out)
            ax[0].cla()
            ax[0].imshow((out[0] - out.min()) / (out.max() - out.min()))
            ax[0].set_title(i)
            ax[1].cla()
            ax[1].imshow(test.squeeze())
            ax[1].set_title(i)
            # scale: convert physical target coordinates to target voxel indices
            out = out - target_origin
            out = out / target_spacing
            out[bad] = -1  # mark entries that were exactly zero before scaling
            OUT.append(out)
        OUT = torch.concatenate(OUT)
        OUT = OUT.numpy()

    if want_fig(14):
        fig.savefig(os.path.join(outdir, 'fig14_high_res.png'))

    if want_fig(15):
        fig, ax = plt.subplots()
        ax.imshow((out[0] - out.min()) / (out.max() - out.min()))
        fig.savefig(os.path.join(outdir, 'fig15_high_res0.png'))

    if want_fig(16):
        i = Xs.shape[0] - 1
        last = Xs.detach()[i]
        fig, ax = plt.subplots()
        ax.imshow((last - last.min()) / (last.max() - last.min()))
        fig.savefig(os.path.join(outdir, 'fig16_high_res1.png'))

    if verbose:
        print(f'Successfully registered {fname_J} to {fname_I} and saved all outputs in {outdir}')
# Run the registration pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()