import emlddmm
import numpy as np
import torch
import argparse
from argparse import RawTextHelpFormatter
import json
import os
import pickle
import matplotlib.pyplot as plt
# Module-wide switch for extra diagnostic printing.
DEBUG = False
# Use the first CUDA device when torch reports one, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Module-level default tensor dtype.
dtype = torch.float
class Graph:
    """
    Graph object with nodes and edges representing image spaces and the transformations between them.

    Attributes
    ----------
    spaces : dict
        Maps a space name to a list ``[index, x]``. ``index`` is the integer node id used
        in ``adj`` and ``x`` is the domain of the corresponding image space.
    adj : list
        Adjacency list. ``adj[i]`` is a dictionary mapping the index of every space directly
        reachable from space ``i`` to the transforms stored on that edge.
    """

    def __init__(self, adj=None, spaces=None):
        # Build fresh containers per instance. Mutable default arguments would be
        # shared by every Graph created without explicit arguments.
        self.adj = [] if adj is None else adj
        self.spaces = {} if spaces is None else spaces

    def _ensure_node(self, index):
        """Pad the adjacency list with empty nodes until ``adj[index]`` exists."""
        while len(self.adj) <= index:
            self.adj.append({})

    def add_space(self, space_name, x=None):
        """Register a space under the next free integer index.

        Nothing happens if ``space_name`` is already registered.

        Parameters
        ----------
        space_name : str
            name of the space
        x : list, optional
            domain of the space; a new empty list is stored when omitted.

        Note
        ----
        This method does not touch ``adj``. ``add_edge`` and ``merge`` create the
        adjacency entries they need.
        """
        if space_name not in self.spaces:
            self.spaces[space_name] = [len(self.spaces), [] if x is None else x]

    def add_edge(self, transforms, src_space, target_space):
        """Store ``transforms`` on the directed edge from ``src_space`` to ``target_space``.

        Parameters
        ----------
        transforms : object
            whatever should be stored for this edge
        src_space : str
            name of the source space (must already be in ``spaces``)
        target_space : str
            name of the target space (must already be in ``spaces``)

        Raises
        ------
        KeyError
            If either space name has not been registered.
        """
        src_id = self.spaces[src_space][0]
        target_id = self.spaces[target_space][0]
        # make sure the source node has a slot before writing into it
        self._ensure_node(src_id)
        self.adj[src_id][target_id] = transforms

    def BFS(self, src, target, v, pred, dist):
        """ Breadth first search

        a modified version of BFS that stores predecessor
        of each vertex in array pred and its distance from source in array dist

        Parameters
        ----------
        src: int
            int value given by corresponding src in spaces dict
        target: int
            int value given by corresponding target in spaces dict
        v: int
            length of spaces dict
        pred: list of ints
            stores predecessor of vertex i at pred[i]; overwritten in place
        dist: list of ints
            stores distance (by number of edges) of vertex i from source vertex; overwritten in place

        Returns
        -------
        bool
            True if a path from src to target is found (including the trivial
            path when src equals target) and False otherwise
        """
        from collections import deque

        # every vertex starts "infinitely" far away and without a predecessor
        for i in range(v):
            dist[i] = 1000000
            pred[i] = -1
        visited = [False] * v

        # visit source first. Distance from source to itself is 0
        visited[src] = True
        dist[src] = 0
        if src == target:
            # a space is trivially connected to itself
            return True

        queue = deque([src])
        while queue:
            u = queue.popleft()
            # a node without an adjacency entry simply has no outgoing edges
            neighbors = self.adj[u] if u < len(self.adj) else ()
            for w in neighbors:
                if visited[w]:
                    continue
                visited[w] = True
                dist[w] = dist[u] + 1
                pred[w] = u
                # We stop BFS when we find the destination.
                if w == target:
                    return True
                queue.append(w)
        return False

    def shortest_path(self, src, target):
        """ Find Shortest Path

        Finds the shortest path between src and target in the adjacency list.

        Parameters
        ----------
        src: int
            int value given by corresponding source in spaces dict
        target: int
            int value given by corresponding target in spaces dict

        Returns
        -------
        path : list of ints
            vertices visited going from src to target, using integer values of the adjacency
            list vertices. Integers can be converted to space names by the spaces dict.
            If the two are not connected, diagnostics are printed and ``[target]`` is returned.

        Example
        -------
        >>> g = Graph(adj=[{1: 'a'}, {2: 'b', 0: 'c'}, {1: 'd'}, {}],
        ...           spaces={'HIST': [0, []], 'MRI': [1, []], 'CCF': [2, []], 'OTHER': [3, []]})
        >>> g.shortest_path(0, 2)
        [0, 1, 2]
        """
        v = len(self.spaces)
        pred = [0] * v  # predecessor of space i on the path from src
        dist = [0] * v  # distance of vertex i from src

        if not self.BFS(src, target, v, pred, dist):
            print("Given target and source are not connected") # TODO: make this more informative
            print('printing source')
            print(src)
            print('printing target')
            print(target)
            print('printing spaces')
            print(self.spaces)
            print('printing adjacency')
            for i, node in enumerate(self.adj):
                print(i)
                print(node)

        # walk the predecessor chain back from the target, then flip it
        path = [target]
        crawl = target
        while pred[crawl] != -1:
            crawl = pred[crawl]
            path.append(crawl)
        path.reverse()
        return path

    def map_points_(self, src_space, transforms, xy_shift=None):
        # this is replaced below
        '''Applies a sequence of transforms to points in source space. If mapping to an image series, maps to the registered domain.

        Parameters
        ----------
        src_space : str
            name of the source space
        transforms : list of emlddmm Transform objects
        xy_shift : torch Tensor
            R^2 vector. Applies a translation in xy for reconstructing in registered space.

        Returns
        -------
        X : torch tensor
            transformed points

        Note
        -----
        The xy_shift optional argument is necessary when reconstructing in registered space or when the origins of the source and target space
        are far apart, e.g. one defines the origin at bregma and the other at the image center. The problem arises because the xy translation is
        contained entirely in the 2D affine transforms.
        '''
        xI = [torch.as_tensor(x) for x in self.spaces[src_space][1]]
        XI = torch.stack(torch.meshgrid(*xI, indexing='ij'))
        # series of 2d transforms can only be applied to the space on which they were computed because the size must match the number of slices.
        # to avoid an error, drop every 2d series transform that follows a volumetric transform.
        seen_volumetric = False
        kept = []
        for t in transforms:
            is_2d_series = t.data.ndim == 3
            if is_2d_series and seen_volumetric:
                continue
            if not is_2d_series:
                seen_volumetric = True
            kept.append(t)
        X = emlddmm.compose_sequence(kept, XI)
        if xy_shift is not None:
            X[1:] += xy_shift[..., None, None, None]
        return X

    def map_points(self, src_space, transforms, xy_shift=None, slice_locations=None):
        # this is reimplemented
        '''Applies a sequence of transforms to points in source space.

        Parameters
        ----------
        src_space : str
            name of the source space
        transforms : list of emlddmm Transform objects
            Applied in order. An empty list returns the untransformed sample grid.
        xy_shift : torch Tensor
            R^2 vector. Applies a translation in xy for reconstructing in registered space.
        slice_locations : sequence, optional
            Slice coordinates along the first axis. Used only when a series of 2D transforms
            does not match the number of slices of the current points; each point is then
            snapped to its nearest slice. Without it such a transform is skipped.

        Returns
        -------
        X : torch tensor
            transformed points

        Note
        -----
        The xy_shift optional argument is necessary when reconstructing in registered space or when the origins of the source and target space
        are far apart, e.g. one defines the origin at bregma and the other at the image center. The problem arises because the xy translation is
        contained entirely in the 2D affine transforms.

        Note
        ----
        From daniel.  Mapping from 2D slices to anything is fine.  The shape will always be the 2D slices.  This is used when mapping imaging data to a 2D slice.
        Mapping to a 2D slice is harder. But I think we can do it somehow with nearest neighbor interpolation.
        '''
        xI = [torch.as_tensor(x) for x in self.spaces[src_space][1]]
        XI = torch.stack(torch.meshgrid(*xI, indexing='ij'))
        # Always work on a copy so the sample grid is never shared with the result.
        # With no transforms there is no dtype to adopt, so the grid dtype is kept.
        if len(transforms) > 0:
            X = XI.to(dtype=transforms[0].data.dtype, copy=True)
        else:
            X = XI.clone()

        for t in transforms:
            if t.data.ndim != 3:
                # simple case
                X = t.apply(X)
                continue

            # t is a series of 2d affine transforms, one matrix per slice
            if X.shape[1] == t.data.shape[0]:
                X = t.apply(X)
                continue

            # The number of slices does not match.
            # for each z coordinate, find the closest slice
            # then map the xy coordinates based on the matrix for their closest slice
            # also, snap the z coordinate exactly to the slice
            print('*'*80)
            print('shapes not compatible')
            if slice_locations is None:
                print('skipping for now because there are no slice locations')
                continue
            print('using slice snapping')

            # NOTE(review): the index computation assumes uniformly spaced slices
            # and at least two slice locations.
            n_slices = len(slice_locations)
            spacing = slice_locations[1] - slice_locations[0]
            X0ind = torch.round((X[0] - slice_locations[0]) / spacing).int()  # the slice coordinate
            X0ind[X0ind < 0] = 0
            X0ind[X0ind >= n_slices] = n_slices - 1
            Xnew = X.clone()
            for i in range(n_slices):
                ind = X0ind == i
                Xnew[0, ind] = slice_locations[i]
                Xnew[1:, ind] = (t.data[i, :2, :2] @ X[1:, ind]) + t.data[i, :2, -1, None]
            X = Xnew

        if xy_shift is not None:
            X[1:] += xy_shift[..., None, None, None]
        return X

    def map_image_(self, src_space, image, target_space, transforms, xy_shift=None):
        # this one is obsolete
        '''Map an image from source space to target space.

        Parameters
        ----------
        src_space : str
            name of source space
        image : array
        target_space : str
            name of target space
        transforms : list of emlddmm Transform objects
        xy_shift : torch Tensor
            R^2 vector. Applies a translation in xy for reconstructing in registered space.

        Returns
        -------
        image : array
            transformed image data

        Note
        -----
        The xy_shift optional argument is necessary when reconstructing in registered space or when the origins of the source and target space
        are far apart, e.g. one defines the origin at bregma and the other at the image center. The problem arises because the xy translation is
        contained entirely in the 2D affine transforms
        '''
        # if the last transforms are 2d series affine, then we will first apply them to the source grid and resample the image in registered space.
        # then apply the other transforms and resample the registered image.
        n_trailing_2d = 0
        for t in reversed(transforms):
            if t.data.ndim != 3:
                break
            n_trailing_2d += 1
        split = len(transforms) - n_trailing_2d
        A2d = transforms[split:]
        transforms = transforms[:split]

        xI = [torch.as_tensor(x) for x in self.spaces[src_space][1]]
        if len(A2d) > 0:
            image = torch.as_tensor(image)
            XI = torch.stack(torch.meshgrid(*xI, indexing='ij'))
            if xy_shift is not None:
                XI[1:] -= xy_shift[..., None, None, None]
            XR = emlddmm.compose_sequence(A2d, XI)
            image = emlddmm.interp(xI, image, XR)
        if len(transforms) > 0:
            X = self.map_points(target_space, transforms, xy_shift=xy_shift)
            image = emlddmm.interp(xI, image, X)
        return image

    def map_image(self, src_space, image, target_space, transforms, xy_shift=None, **kwargs):
        '''Map an image from source space to target space.

        Parameters
        ----------
        src_space : str
            name of source space
        image : array
        target_space : str
            name of target space
        transforms : list of emlddmm Transform objects, or precomputed sample points
            When a list is given the sample points are computed with ``map_points`` on the
            target space. Anything else is used directly as the sample points.
        xy_shift : torch Tensor
            R^2 vector. Applies a translation in xy for reconstructing in registered space.
            Only used when ``transforms`` is a list.
        kwargs : dict
            keyword args to be passed to emlddmm interpolation, which will be passed to torch grid sample

        Returns
        -------
        image : array
            transformed image data

        Note
        -----
        The xy_shift optional argument is necessary when reconstructing in registered space or when the origins of the source and target space
        are far apart, e.g. one defines the origin at bregma and the other at the image center. The problem arises because the xy translation is
        contained entirely in the 2D affine transforms
        '''
        xI = [torch.as_tensor(x) for x in self.spaces[src_space][1]]
        image = torch.as_tensor(image)
        if isinstance(transforms, list):
            X = self.map_points(target_space, transforms, xy_shift=xy_shift)
        else:
            X = transforms
        return emlddmm.interp(xI, image, X, **kwargs)

    def merge(self, new_graph):
        ''' Merge two graphs

        Spaces of ``new_graph`` that are not yet known are appended to this graph with
        fresh indices, and every edge of ``new_graph`` is copied over using the
        re-indexed endpoints. ``new_graph`` itself is left unmodified.

        Parameters
        ----------
        new_graph : emlddmm Graph object

        Returns
        -------
        graph : emlddmm Graph object
            Current graph (``self``, modified in place) merged with the new graph.
        '''
        graph = self
        # merge spaces dict
        id_map = {}  # maps space indices of new_graph to indices in this graph
        for key, value in new_graph.spaces.items():
            if key in graph.spaces:
                id_map[value[0]] = graph.spaces[key][0]
                continue
            new_id = len(graph.spaces)
            id_map[value[0]] = new_id
            # store a re-indexed copy so new_graph keeps its own numbering
            entry = list(value)
            entry[0] = new_id
            graph.spaces[key] = entry
            graph._ensure_node(new_id)  # add a node in the adjacency list for each new space
        # merge adjacency list
        for i, node in enumerate(new_graph.adj):  # for each node in the new graph
            src = id_map[i]  # index of that space in this graph
            graph._ensure_node(src)
            for j, transform in node.items():  # and for each node to which it connects
                graph.adj[src][id_map[j]] = transform
        return graph
def graph_reconstruct_(graph, out, I, target_space, target_fnames=None):
    # this version is obsolete
    ''' Apply Transformation

    Applies affine matrix and velocity field transforms to map source points to target points. Saves displacement field from source points to target points
    (i.e. difference between transformed coordinates and input coordinates), and determinant of Jacobian for 3d source spaces. Also saves transformed image in vtk format.

    Parameters
    ----------
    graph : emlddmm Graph object
    out: str
        path to registration outputs parent directory
    I : emlddmm Image
        ``I.x`` and ``I.data`` are converted to torch tensors in place.
    target_space : str
        name of the space to which image I will be transformed.
    target_fnames : list, optional
        list of file names; only necessary if target is a series of 2d slices.
        A non-empty list marks the target as a series.

    TODO
    ----
    Check why the registered space histology is not working. (march 27, 2023)
    I think the issue is that there is actually no time to do it.
    If I say to reconstruct one space to itself, then it says not connected and gives an error.
    There needs to be another way.
    '''
    if target_fnames is None:
        target_fnames = []

    def jacobian(X, dv):
        # gradient of every component along the three spatial axes, rearranged so
        # that the two trailing axes hold one 3x3 matrix per grid point
        return np.stack(np.gradient(X, dv[0], dv[1], dv[2], axis=(1, 2, 3))).transpose(2, 3, 4, 0, 1)

    dtype = torch.float
    device = 'cpu'
    # first we get the sample points in the target space J
    xJ = graph.spaces[target_space][1]
    xJ = [torch.as_tensor(x, dtype=dtype, device=device) for x in xJ]
    target_space_idx = graph.spaces[target_space][0]
    # then we get the sample points in the source space I
    I.x = [torch.as_tensor(x, dtype=dtype, device=device) for x in I.x]
    I.data = torch.as_tensor(I.data, dtype=dtype, device=device)
    src_space_idx = graph.spaces[I.space][0]
    # backward transform, map the points in target space J back to the source space
    path = graph.shortest_path(target_space_idx, src_space_idx)
    transforms = graph.transforms(path)
    XJ = torch.stack(torch.meshgrid(xJ, indexing='ij'))
    fXJ = graph.map_points(target_space, transforms)
    # and then transform the image by sampling it at these points
    fI = graph.map_image(I.space, I.data, target_space, transforms)

    # now we are going to write the outputs.  This involves several different cases
    #
    # Four cases:
    # 1) series to series
    #     Save reconstructions of I slices in J space.
    #     Apply A2di to J points and resample I.
    # 2) volume to series
    #     A) Save volume to registered images in {target_space}_registered/{I.space}_{I.name}_to_{target_space}_registered/images/,
    #     and volume to input images in {target_space}_input/{I.space}_{I.name}_to_{target_space}_input/images/
    #     B) Save volume to registered and volume to input displacement in {target_space}_registered/{I.space}_{I.name}_to_{target_space}_registered/transforms/ and
    #     {target_space}_input/{I.space}_{I.name}_to_{target_space}_input/transforms/, respectively.
    # 3) series to volume
    #     A) save series input to registered space images in {I.space}_registered/{I.space}_input_to_{I.space}_registered/images/
    #     B) Save series to volume image in {target_space}/{I.space}_{I.name}_input_to_{target_space}/images/
    #     C) Save series to volume detjac and displacement in {target_space}/{I.space}_{I.name}_registered_to_{target_space}/transforms/
    # 4) volume to volume
    #     A) Save out I to J image in {target_space}/{I.space}_{I.name}_to_{target_space}/images.
    #     B) Save out I to J detjac and displacement in {target_space}/{I.space}_{I.name}_to_{target_space}/transforms/
    from_series = I.title == 'slice_dataset'
    to_series = len(target_fnames) != 0 # we don't have the image title so we need to check for a list of file names
    print(f'reconstructing {I.space} {I.name} in {target_space} space')

    if from_series and to_series:
        # series to series
        # Assumes J and I have the same space dimensions
        I_to_J_out = os.path.join(out, f'{target_space}_input/{I.space}_{I.name}_input_to_{target_space}_input/images/')
        os.makedirs(I_to_J_out, exist_ok=True)
        for i in range(I.data.shape[1]):
            # each slice is written with a two-point extent along the slice axis
            x = [[I.x[0][i], I.x[0][i]+10], I.x[1], I.x[2]]
            # I to J
            img = fI[:, i, None, ...]
            title = f'{I.space}_input_{I.fnames()[i]}_to_{target_space}_input_{target_fnames[i]}'
            emlddmm.write_vtk_data(os.path.join(I_to_J_out, title + '.vtk'), x, img, title)

    elif to_series:
        # volume to series
        # we need I transformed to J registered space
        path = graph.shortest_path(graph.spaces[target_space][0], graph.spaces[I.space][0])
        # omit the 2d series transforms (R^-1) which take points from 2d to 2d or input to registered.
        # in this case we can just remove all 2d series transforms.
        # A filtered copy is built instead of deleting while iterating, which skipped
        # the element following every deletion.
        mean_translation = None  # stays None when the path holds no 2d series transform
        volumetric = []
        for t in graph.transforms(path):
            if t.data.ndim == 3:
                mean_translation = torch.mean(t.data[:, :2, -1], dim=0)
            else:
                volumetric.append(t)
        phiiAiXJ = graph.map_points(target_space, volumetric, xy_shift=mean_translation)
        AphiI = graph.map_image(I.space, I.data, target_space, volumetric, xy_shift=mean_translation)
        # get I to J registered and I to J input displacements
        reg_disp = (phiiAiXJ - XJ)[None]
        input_disp = (fXJ - XJ)[None]
        # setup output paths
        I_to_Ji_out = os.path.join(out, f'{target_space}_input/{I.space}_{I.name}_to_{target_space}_input/images/')
        os.makedirs(I_to_Ji_out, exist_ok=True)
        I_to_Jr_out = os.path.join(out, f'{target_space}_registered/{I.space}_{I.name}_to_{target_space}_registered/images/')
        os.makedirs(I_to_Jr_out, exist_ok=True)
        reg_disp_out = os.path.join(out, f'{target_space}_registered/{I.space}_{I.name}_to_{target_space}_registered/transforms/')
        os.makedirs(reg_disp_out, exist_ok=True)
        input_disp_out = os.path.join(out, f'{target_space}_input/{I.space}_{I.name}_to_{target_space}_input/transforms/')
        os.makedirs(input_disp_out, exist_ok=True)
        for i in range(len(xJ[0])):
            x = [[xJ[0][i], xJ[0][i]+10], xJ[1], xJ[2]]
            # volume to input series
            # save image
            img = fI[:, i, None, ...]
            title = f'{I.space}_{I.name}_to_{target_space}_input_{target_fnames[i]}'
            emlddmm.write_vtk_data(os.path.join(I_to_Ji_out, title + '.vtk'), x, img, title)
            # save displacement
            disp = input_disp[:, :, i, None]
            title = f'{target_space}_input_{target_fnames[i]}_to_{I.space}_displacement'
            emlddmm.write_vtk_data(os.path.join(input_disp_out, title + '.vtk'), x, disp, title)
            # volume to registered series
            img = AphiI[:, i, None, ...]
            title = f'{I.space}_{I.name}_to_{target_space}_registered_{target_fnames[i]}'
            emlddmm.write_vtk_data(os.path.join(I_to_Jr_out, title + '.vtk'), x, img, title)
            # save displacement
            disp = reg_disp[:, :, i, None]
            title = f'{target_space}_registered_{target_fnames[i]}_to_{I.space}_displacement'
            emlddmm.write_vtk_data(os.path.join(reg_disp_out, title + '.vtk'), x, disp, title)

    elif from_series:
        # series to volume
        # I input to I registered space
        path = graph.shortest_path(graph.spaces[target_space][0], graph.spaces[I.space][0])
        # get the last 2d series transforms (R) which take points from 2d to 2d or registered to input
        transforms = graph.transforms(path)
        idx = 0
        for k, t in enumerate(transforms[::-1]):
            if t.data.ndim != 3:
                idx = k
                break
        # NOTE(review): when the final transform is not a 2d series idx is 0 and
        # transforms[-0:] selects every transform - confirm this is intended.
        A2ds = transforms[-idx:]
        mean_translation = torch.mean(A2ds[0].data[:,:2,-1], dim=0)
        RiI = graph.map_image(I.space, I.data, I.space, A2ds, xy_shift=mean_translation)
        Ii_to_Ir_out = os.path.join(out, f'{I.space}_registered/{I.space}_input_to_{I.space}_registered/images/')
        os.makedirs(Ii_to_Ir_out, exist_ok=True)
        # input to registered images, one file per slice (mirrors the other series
        # branches; previously a single slice chosen by a leaked loop index was written)
        for i in range(I.data.shape[1]):
            x = [[I.x[0][i], I.x[0][i]+10], I.x[1], I.x[2]]
            img = RiI[:, i, None, ...]
            title = f'{I.space}_input_{I.fnames()[i]}_to_{I.space}_registered_{I.fnames()[i]}'
            emlddmm.write_vtk_data(os.path.join(Ii_to_Ir_out, title + '.vtk'), x, img, title)
        # I to J
        img = graph.map_image(I.space, I.data, target_space, transforms, xy_shift=mean_translation)
        title = f'{I.space}_{I.name}_input_to_{target_space}'
        I_to_J_imgs = os.path.join(out, f'{target_space}/{I.space}_{I.name}_input_to_{target_space}/images/')
        os.makedirs(I_to_J_imgs, exist_ok=True)
        emlddmm.write_vtk_data(os.path.join(I_to_J_imgs, title + '.vtk'), xJ, img, title)
        # disp
        # we need J to I registered points
        disp = (fXJ - XJ)[None]
        title = f'{I.space}_{I.name}_registered_to_{target_space}_displacement'
        I_to_J_transforms = os.path.join(out, f'{target_space}/{I.space}_{I.name}_registered_to_{target_space}/transforms/')
        os.makedirs(I_to_J_transforms, exist_ok=True)
        emlddmm.write_vtk_data(os.path.join(I_to_J_transforms, title + '.vtk'), xJ, disp, title)
        # determinant of jacobian
        dv = [(x[1]-x[0]) for x in xJ]
        jac = jacobian(fXJ, dv)
        detjac = np.linalg.det(jac)[None]
        title = f'{I.space}_{I.name}_registered_to_{target_space}_detjac'
        emlddmm.write_vtk_data(os.path.join(I_to_J_transforms, title + '.vtk'), xJ, detjac, title)

    else: # this is the volume to volume case
        # I to J
        # save image
        img = fI
        title = f'{I.space}_{I.name}_to_{target_space}'
        I_to_J_imgs = os.path.join(out, f'{target_space}/{I.space}_{I.name}_to_{target_space}/images/')
        os.makedirs(I_to_J_imgs, exist_ok=True)
        emlddmm.write_vtk_data(os.path.join(I_to_J_imgs, title + '.vtk'), xJ, img, title)
        # save displacement
        disp = (fXJ - XJ)[None]
        title = f'{I.space}_{I.name}_to_{target_space}_displacement'
        I_to_J_transforms = os.path.join(out, f'{target_space}/{I.space}_{I.name}_to_{target_space}/transforms/')
        os.makedirs(I_to_J_transforms, exist_ok=True)
        emlddmm.write_vtk_data(os.path.join(I_to_J_transforms, title + '.vtk'), xJ, disp, title)
        # save determinant of jacobian
        dv = [(x[1]-x[0]) for x in xJ]
        jac = jacobian(fXJ, dv)
        detjac = np.linalg.det(jac)[None]
        title = f'{I.space}_{I.name}_to_{target_space}_detjac'
        emlddmm.write_vtk_data(os.path.join(I_to_J_transforms, title + '.vtk'), xJ, detjac, title)
[docs]def graph_reconstruct(graph, out, I, target_space, target_fnames=[]):
    # this version is modified by daniel and does not treat "registered as a special case"
    # todo, transform all images.  every time you transform an image, you also write it out, and write out the transform
    # if you transform an annotation image to a slice dataset, you should also output geojson
    # in this case we'll have to find an ontology to work with cshl
    ''' Apply Transformation
    Applies affine matrix and velocity field transforms to map source points to target points. Saves displacement field from source points to target points
    (i.e. difference between transformed coordinates and input coordinates), and determinant of Jacobian for 3d source spaces. Also saves transformed image in vtk format.
    Parameters
    ----------
    graph : emlddmm Graph object
    out: str
        path to registration outputs parent directory
    I : emlddmm Image   
    target_space : str
        name of the space to which image I will be transformed.
    target_fnames : list
        list of file names; only necessary if target is a series of 2d slices.
        
    TODO
    ----
    Check why the registered space histology is not working. (march 27, 2023)
    I think the issue is that there is actually no time to do it.
    If I say to reconstruct one space to itself, then it says not connected and gives an error.  
    There needs to be another way.
    '''
    jacobian = lambda X,dv : np.stack(np.gradient(X, dv[0],dv[1],dv[2], axis=(1,2,3))).transpose(2,3,4,0,1)
    dtype = torch.float
    device = 'cpu'
    
    print(f'about to transform {I.space}, {I.title} to the space {target_space}')
    
    # now we are going to write the outputs.  This involves several different cases
    '''
    Three cases:
    1) series to series
        Save reconstructions of I slices in J space. 
        Apply A2di to J points and resample I.
    2) volume to series
        A) Save volume to registered images in {target_space}_registered/{I.space}_{I.name}_to_{target_space}_registered/images/,
        and volume to input images in {target_space}_input/{I.space}_{I.name}_to_{target_space}_input/images/
        B) Save volume to registered and volume to input displacement in {target_space}_registered/{I.space}_{I.name}_to_{target_space}_registered/transforms/ and 
        {target_space}_input/{I.space}_{I.name}_to_{target_space}_input/transforms/, respectively.
    3) series to volume
        A) save series input to registered space images in {I.space}_registered/{I.space}_{I.name}_input_to_{I.space}_registered/images/
        C) Save series to volume image in {target_space}/{I.space}_{I.name}_input_to_{I.space}/images/
        D) Save series to volume detjac and displacement in {target_space}/{I.space}_{I.name}_registered_to_{target_space}/transforms/
    4) volume to volume 
        A) Save out I to J image in {target_space}/{I.space}_to_{target_space}/images.
        B) Save out I to J detjac and displacement in {target_space}/{I.space}_to_{target_space}/transforms/
    '''
    from_series = I.title == 'slice_dataset'
    to_series = len(target_fnames) != 0 # we don't have the image title so we need to check for a list of file names
    print(f'Is the source a series? {from_series}')
    print(f'Is the target a series? {to_series}')
    # convert data to torch
    # J.x = [torch.as_tensor(x, dtype=dtype, device=device) for x in J.x]
    # J.data = torch.as_tensor(J.data, dtype=dtype, device=device)
    # first we get the sample points in the target space J
    xJ = graph.spaces[target_space][1]
    xJ = [torch.as_tensor(x, dtype=dtype, device=device) for x in xJ]
    target_space_idx = graph.spaces[target_space][0]
    # then we get the sample points in the source space I            
    src_space_idx = graph.spaces[I.space][0]
    # backward transform, map the points in target space J back to the source space
    path = graph.shortest_path(target_space_idx, src_space_idx)
    transforms = graph.transforms(path)
    if DEBUG:
        print(transforms)
        for t in transforms:
            if t.data.shape == torch.Size([4,4]):
                print(t.data)        
    XJ = torch.stack(torch.meshgrid(xJ, indexing='ij'))    
    if from_series and not to_series:
        # special case for mapping 2D data to 3D
        # we need to snap onto grid points
        #         
        # note in the case of atlas to registered space
        # there is no issue, since the transforms are just v A
        # the slice locations argument will be ignored        
        fXJ = graph.map_points(target_space, transforms, slice_locations=I.x[0])
    else:
        fXJ = graph.map_points(target_space, transforms)
    
    # and then transform the image by sampling it at these points
    if I.data is not None:
        I.x = [torch.as_tensor(x, dtype=dtype, device=device) for x in I.x]
        Idtype = I.data.dtype
        I.data = torch.as_tensor(I.data, dtype=dtype, device=device)
        
        if I.annotation:            
            print('found annotation, using nearest')
            fI = graph.map_image(I.space, I.data, target_space, fXJ, mode='nearest')
            # convert the dtype back
            if Idtype == np.uint8:
                fI = torch.as_tensor(fI,dtype=torch.uint8)
            elif Idtype == np.uint16:
                fI = torch.as_tensor(fI,dtype=torch.uint16) # I think this doesn't exist
            elif Idtype == np.uint32:
                fI = torch.as_tensor(fI,dtype=torch.uint32) 
        else:
            fI = graph.map_image(I.space, I.data, target_space, fXJ)
        
        # write out a qc file        
        fig,ax = emlddmm.draw(fI,xJ)
        qc_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/qc/')
        os.makedirs(qc_out, exist_ok=True)
        fig.savefig(os.path.join(qc_out,f'{I.space}_{I.name}_to_{target_space}_full.jpg'))
        
    else:
        fI = None
    # daniel asks, does the above map points and images work in all cases? 
    
    
    
    
    if from_series and to_series:
        print('This is a 2D series to a 2D series')
        # this one is pretty straight forward
        # we write out transforms and we write out images
        # series to series
        # Assumes J and I have the same space dimensions
        I_to_J_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/images/')
        reg_disp_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/transforms/')
        
        if fI is not None:
            os.makedirs(I_to_J_out, exist_ok=True)
        
        os.makedirs(reg_disp_out, exist_ok=True)
        for i in range(len(xJ[0])):#range(I.data.shape[1]): # note the number of slices in I and J must be equal
            
            # I to J
            if fI is not None: # if there are no images, don't write them
                x = [[I.x[0][i], I.x[0][i]+10], I.x[1], I.x[2]]
                img = fI[:, i, None, ...]
                title = f'{I.space}_{I.name}_{I.fnames()[i]}_to_{target_space}_{target_fnames[i]}'
                emlddmm.write_vtk_data(os.path.join(I_to_J_out, f'{I.space}_{I.name}_{I.fnames()[i]}_to_{target_space}_{target_fnames[i]}.vtk'), x, img, title)
            # also write out transforms as a matrix
            #print(transforms)
            if np.all([t.data.ndim==3 for t in transforms]):
                #print('All transforms are matrices')
                #print(f'{transforms[0].data.shape}')
                output = transforms[0].data[i].clone()
                for t in transforms[1:]:
                    output = t.data[i]@output
                # TODO: something is wrong with this path, come back to it later
                output_transform_name = os.path.join(reg_disp_out,f'{target_space}_{target_fnames[i]}_to_{I.space}_{I.fnames()[i]}_matrix.txt')
                print(output_transform_name)
                emlddmm.write_matrix_data(output_transform_name, output)
                
            else:
                # write out displacement fields                
                print('writing displacement fields for 2d to 2d not implement yet')
                asdf
                pass
        
    elif to_series:
        print('This is 3D to 2D')
        # recall I don't need special cases for registered space anymore
        reg_disp = (fXJ - XJ)[None] # recall we output a 1x3xslicexrowxcol
        # get the output 
        # setup output paths
        I_to_J_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/images/')        
        os.makedirs(I_to_J_out, exist_ok=True)        
        reg_disp_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/transforms/')        
        os.makedirs(reg_disp_out, exist_ok=True)
        if I.annotation:
            #print('*'*80)
            #print('This is an annotation, we will make geojson outputs')
            geojson_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/geojson/')
            os.makedirs(geojson_out, exist_ok=True)
        
        qc_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/qc/')
        os.makedirs(qc_out, exist_ok=True)
        fig,ax = plt.subplots()
        for i in range(len(xJ[0])): # loop over all slices
            ax.cla()
            x = [[xJ[0][i], xJ[0][i]+10], xJ[1], xJ[2]] #TODO this +10 is hard coded just to make it work in 3D, it's not actually slice thickness
            # save image
            img = fI[:, i, None, ...]
            title = f'{I.space}_{I.name}_to_{target_space}_{target_fnames[i]}'
            emlddmm.write_vtk_data(os.path.join(I_to_J_out, title + '.vtk'), x, img, title)
            # save displacement
            disp = reg_disp[:, :, i, None]
            title = f'{target_space}_{target_fnames[i]}_to_{I.space}_displacement'
            emlddmm.write_vtk_data(os.path.join(reg_disp_out, title + '.vtk'), x, disp, title)
            # TODO: for qc I would need to get the target fnames and load them
            # TODO: we need to load an ontology
            if I.annotation:                
                output_geojson = {'type': 'FeatureCollection', 'features': []}
                # generate geojson curves for each label
                labels = np.unique(img.cpu().numpy())
                
                count = 0
                for l in labels[1:]: # ignore background label
                    coordinates = [] # one set of coordinates per label
                    cs = ax.contour(xJ[-1],xJ[-2],(img.cpu().numpy()[0,0]==l).astype(float),[0.5],linewidths=1.0,colors='k')
                    paths = cs.collections[0].get_paths()
                    for path in paths:
                        vertices = np.array([seg[0] for seg in path.iter_segments()])
                        meanpos = (np.max(vertices,0) + np.min(vertices,0))/2
                        # put some text
                        if vertices.shape[0] > 20:
                            ax.text(meanpos[0],meanpos[1],str(l),
                                    horizontalalignment='center', verticalalignment='center',
                                    fontsize=4, bbox={'color':np.array([1.0,1.0,1.0,0.5]), 'pad':0})
                        coordinates.append([[ list(seg[0]) for seg in path.iter_segments()]])
                    # TODO get an actual ontology
                    geometry = {'type': 'MultiPolygon', 'coordinates': coordinates}
                    
                    # TODO: for the first feature add something to properties that says how big the suggested image is
                    # it should be a dictionary like 'suggested_image':{'n':[],'o':[],'d':[]}
                    # we will use the convention that we go from the first pixel to the last pixel i
                    
                    properties = {'name':str(l), 'acronym':str(l)}
                    if count == 0:
                        # what does the 32x upsampled space look like?
                        # each pixel becomes 32 pixels, centered at the same spot
                        nup = 32
                        offsets = (np.arange(nup) - (nup-1)/2)/nup
                        dJ = np.array([np.array(x[1] - x[0]) for x in xJ])
                        
                        tmp0 = (np.array(xJ[-2][...,None]) + offsets*np.array(dJ[-2])).reshape(-1)
                        tmp1 = (np.array(xJ[-1][...,None]) + offsets*np.array(dJ[-1])).reshape(-1)
                        xup = [tmp0,tmp1]
                        properties['suggested_image'] = {'n':[len(x) for x in xup],
                                                         'd':[dJ[1].item()/nup,dJ[2].item()/nup],
                                                         'o':[xup[0][0].item(),xup[1][0].item()]}                        
                    output_geojson['features'].append({'type': 'Feature', 'id': int(l), 'properties': properties, 'geometry': geometry})
                    
                    count += 1                
                with open(os.path.join(geojson_out,f'{I.space}_{I.name}_to_{target_space}_{target_fnames[i]}.geojson'),'wt') as f:
                    json.dump(output_geojson, f, indent=2)
                fig.savefig(os.path.join(qc_out,f'{I.space}_{I.name}_to_{target_space}_{target_fnames[i]}.jpg'))
        plt.close(fig)
                    
    elif from_series:
        print(f'reconstructing {I.space} {I.name} in {target_space} space')
        print('This is mapping 2D images to a 3D space,')
        
        # again, we don't do anything special here with introducing new spaces        
        # we just write out fI and FXJ
        reg_disp = (fXJ - XJ)[None]                
        # set up the paths
        I_to_J_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/images/')
        if fI is not None:
            os.makedirs(I_to_J_out, exist_ok=True)
        reg_disp_out = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/transforms/')
        os.makedirs(reg_disp_out, exist_ok=True)
        # write out image
        if fI is not None:
            title = f'{I.space}_{I.name}_to_{target_space}'
            emlddmm.write_vtk_data(os.path.join(I_to_J_out, title + '.vtk'), xJ, fI, title)
        # write out transform, note I have modified map points so these will point exactly at a slice, so that we can do matrix multiplication, if necessary
        title = f'{target_space}_to_{I.space}_displacement'
        emlddmm.write_vtk_data(os.path.join(reg_disp_out, title + '.vtk'), xJ, reg_disp, title)
    else: # this is the volume to volume case
        print(f'reconstructing {I.space} {I.name} in {target_space} space')
        print(f'This is mapping a 3D image to a 3D image')
        # question, does the registered space count as a 3D image? No, not for our purposes
        
        
        # volume to volume
        # I to J
        # save image
        img = fI
        title = f'{I.space}_{I.name}_to_{target_space}'
        I_to_J_imgs = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/images/')                
        os.makedirs(I_to_J_imgs, exist_ok=True)
        emlddmm.write_vtk_data(os.path.join(I_to_J_imgs, title + '.vtk'), xJ, img, title)
        # save displacement
        disp = (fXJ - XJ)[None]
        title = f'{I.space}_to_{target_space}_displacement'
        I_to_J_transforms = os.path.join(out, f'{target_space}/{I.space}_to_{target_space}/transforms/')
        os.makedirs(I_to_J_transforms, exist_ok=True)
        emlddmm.write_vtk_data(os.path.join(I_to_J_transforms, title + '.vtk'), xJ, disp, title)
        # save determinant of jacobian
        dv = [(x[1]-x[0]) for x in xJ]
        jac = jacobian(fXJ, dv)
        detjac = np.linalg.det(jac)[None]
        title = f'{I.space}_{I.name}_to_{target_space}_detjac'
        emlddmm.write_vtk_data(os.path.join(I_to_J_transforms, title + '.vtk'), xJ, detjac, title) 
def run_registrations(reg_list):
    """ Run Registrations

    Runs a sequence of registrations given by reg_list. Saves transforms, qc images, reconstructed images,
    displacement fields, and determinant of Jacobian of displacements. Also builds the transform graph.

    Parameters
    ----------
    reg_list : list of dicts
        each dict in reg_list specifies the source image path ('source'), target image path ('target'),
        source and target space/image names ('registration'), the path to a registration
        configuration JSON file ('config'), and the output directory ('output').

    Returns
    -------
    graph : Graph
        Transformation graph with one node per space named in reg_list, plus one
        "{space}_registered" node for every 3D to 2D registration. Edges hold the forward and
        backward transforms computed by each registration.

    Example
    -------
    >>> reg_list = [{'registration':[['CCF','average_template_50'],['MRI','masked']],
                     'source': '/path/to/average_template_50.vtk',
                     'target': '/path/to/HR_NIHxCSHL_50um_14T_M1_masked.vtk',
                     'config': '/path/to/configMD816_MR_to_CCF.json',
                     'output': 'outputs/example_output'},
                    {'registration':[['MRI','masked'], ['HIST','Nissl']],
                     'source': '/path/to/HR_NIHxCSHL_50um_14T_M1_masked.vtk',
                     'target': '/path/to/MD816_STIF',
                     'config': '/path/to/configMD816_Nissl_to_MR.json',
                     'output': 'outputs/example_output'}]
    >>> run_registrations(reg_list)
    """
    # initialize graph
    # Graph.__init__ declares mutable defaults (adj=[], spaces={}); a bare Graph() would share
    # its spaces dict with every other default-constructed Graph, so spaces from one call
    # would leak into the next. Pass fresh containers explicitly.
    # note from daniel
    # for 2D registration, the spaces will just say "histology".  Is that good enough?
    graph = Graph(adj=[], spaces={})
    for reg in reg_list:
        # source and target space names of each registration
        for space_name in (reg['registration'][0][0], reg['registration'][1][0]):
            if space_name not in graph.spaces:
                print(f'adding space {space_name} to graph')
                graph.add_space(space_name)
    # one (initially empty) adjacency dict per space
    graph.adj = [{} for _ in graph.spaces]

    # perform registrations
    for r in reg_list:
        source = r['source']
        target = r['target']
        registration = r['registration']
        output_dir = r['output']
        print(f"registering {source} to {target}")
        with open(r['config']) as f:
            config = json.load(f)

        I = emlddmm.Image(space=registration[0][0], name=registration[0][1], fpath=source)
        print(f'Source I shape {I.data.shape}')
        if I.title == 'slice_dataset': # if series to series, both images must share the same coordinate grid
            J = emlddmm.Image(space=registration[1][0], name=registration[1][1], fpath=target, mask=True, x=I.x)
        else:
            J = emlddmm.Image(space=registration[1][0], name=registration[1][1], fpath=target, mask=True)
        print(f'Target J shape {J.data.shape}')

        # add domains to graph before downsampling for later
        graph.spaces[I.space][1] = I.x
        graph.spaces[J.space][1] = J.x

        # initial downsampling: apply the smallest per-axis factor once up front, then
        # express the remaining per-scale factors relative to it
        downIs = config['downI']
        downJs = config['downJ']
        mindownI = np.min(np.array(downIs), 0)
        mindownJ = np.min(np.array(downJs), 0)
        I.x, I.data, I.mask = I.downsample(mindownI)
        J.x, J.data, J.mask = J.downsample(mindownJ)
        # update our config variable
        config['downI'] = [list((np.array(d)/mindownI).astype(int)) for d in downIs]
        config['downJ'] = [list((np.array(d)/mindownJ).astype(int)) for d in downJs]

        # registration
        output = emlddmm.emlddmm_multiscale(I=I.data,xI=[I.x],J=J.data,xJ=[J.x],W0=J.mask,full_outputs=False,**config)
        final = output[-1] # results of the last entry in the multiscale output

        # for series to series:
        #     1) Save rigid transforms in {source}_input/{target}_to_{source}_input
        #     2) Save  qc source input and target to source_input in {source}_input/{target}_to_{source}_input/qc/
        #     3) Save qc target input and source to target_input in {target}_input/{source}_to_{target}_input/qc/
        # for volume (source) to series (target):
        #     1) Save rigid transforms (R) in {target}_registered/{target}_input_to_registered/transforms/
        #     2) A.txt and velocity.vtk map points in source to match target, which are
        #     used to resample target images in source space. Save them in {source}/{target}_registered_to_{source}/transforms/
        #     3) Save qc source original and target to source in {source}/{target}_registered_to_{source}/qc/
        #     4) Save qc target_input and source to target_input in {target}_input/{source}_to_{target}_input/qc/
        #     5) Save qc target_registered and source to target_registered in {target}_registered/{source}_to_{target}_registered/qc/
        # for volume to volume:
        #     1) Save A.txt and velocity.vtk in {source}/{target}_to_{source}/transforms/
        emlddmm.write_transform_outputs(output_dir, final, I, J)
        # TODO: check if there are annotations in this space, if so add them below
        emlddmm.write_qc_outputs(output_dir, final, I, J)

        # forward ('f') and backward ('b') versions of the affine and of the velocity-field transform
        A = emlddmm.Transform(final['A'], 'f')
        Ai = emlddmm.Transform(final['A'], 'b')
        xv = final['xv']
        v = final['v']
        phi = emlddmm.Transform(v, direction='f', domain=xv)
        phii = emlddmm.Transform(v, direction='b', domain=xv)

        if 'A2d' not in final:
            graph.add_edge([phi, A], I.space, J.space)
            graph.add_edge([Ai, phii], J.space, I.space)
            continue

        A2d = emlddmm.Transform(final['A2d'], 'f')
        A2di = emlddmm.Transform(final['A2d'], 'b')
        # check if this is a 3D to 2D map
        if J.title == 'slice_dataset' and I.title != 'slice_dataset':
            # NOTE from daniel
            # what I want to do here is create a new space, called nissl_registered
            # when I define the registered space I'll have to define the coordinates, which will be shifted using the mean shift
            # and add two sets of edges one containing only the A2d
            print('This is a 3D to 2D map')
            print(f'From space {I.space} to space {J.space}')
            print(f'Adding a new registered space {J.space}_registered')
            registered_space = J.space+'_registered'
            # if the mean translation from registered to input
            # is 5 pixels up
            # then we want the sample points in registered space to be 5 pixels down
            mean_translation = torch.mean(A2d.data[:,:2,-1], dim=0).clone().detach().cpu().numpy()

            # note we use the same z coordinate (J.x[0])
            registered_x = [J.x[0], J.x[1] - mean_translation[0], J.x[2] - mean_translation[1]]

            if registered_space not in graph.spaces:
                graph.add_space(registered_space)
                # NOTE: the adjacency list needs one dictionary per space; only grow it when
                # the space is really new so adj and spaces stay the same length
                graph.adj.append({})
            graph.spaces[registered_space][1] = registered_x

            graph.add_edge([phi, A], I.space, registered_space)
            graph.add_edge([A2d], registered_space, J.space)
            graph.add_edge([Ai,phii],registered_space,I.space)
            graph.add_edge([A2di],J.space,registered_space)
        else:
            # this is a 2d to 2d map, using only 2D transforms
            graph.add_edge([A2d], I.space, J.space)
            graph.add_edge([A2di], J.space, I.space)

    return graph
def _reference_image_path(sip, space):
    """Return the path of the first image listed for ``space`` in the space-image-path dict.

    A space with no images of its own (the "*_registered" spaces) borrows the first image
    of the space whose name is obtained by removing '_registered'. The path is only used
    to recover file names when the data is a slice series.
    """
    images = sip[space] or sip[space.replace('_registered', '')]
    return list(images.values())[0]


def main():
    """ Main

    Main function for parsing input arguments, calculating registrations and applying transformations.

    Raises
    ------
    KeyError
        If "space_image_path" is missing from the input JSON, or if "registrations" is given
        without "configs".
    ValueError
        If the input JSON contains neither "registrations" nor "graph".
    NotImplementedError
        If "transforms" is requested without "transform_all" (not supported yet).

    Example
    -------
    $ python transformation_graph.py --infile GDMInput.json
    """
    help_string = "Arg parser looks for one argument, \'--infile\', which is a JSON file with the following entries: \n\
1) \"space_image_path\": a list of lists, each containing the space name, image name, and path to an image or image series. (Required)\n\
2) \"registrations\": a list of lists, each containing two space-image pairs to be registered. e.g. [[[\"HIST\", \"nissl\"], [\"MRI\", \"masked\"]],\n\
                                                                                                    [[\"MRI\", \"masked\"], [\"CCF\", \"average_template_50\"]],\n\
                                                                                                    [[\"MRI\", \"masked\"], [\"CT\", \"masked\"]]]\n\
                    If registrations were previously computed, this argument may be omitted, and the \"graph\" argument can be used to perform additional reconstructions.\n\
3) \"configs\": list of paths to registration config JSON files, the order of which corresponds to the order of registrations listed in the previous value. (Required if computing registrations)\n\
4) \"output\": output directory which will be the output hierarchy root. If none is given, it is set to the current directory.\n\
5) \"transforms\": transforms to apply after they are computed from registration. Only necessary if \"transform_all\" is False. Format is the same as \"registrations\".\n\
6) \"transform_all\": bool. Performs all possible reconstructions given the transform graph.\n\
7) \"graph\": path to \"graph.p\" pickle file saved to output after performing registrations. (Only required for reconstructing images from previously computed registrations)"
    parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter)
    # required=True: otherwise a missing --infile only surfaces as a TypeError when indexing None
    parser.add_argument('--infile', nargs=1, required=True,
                        help=help_string,
                        type=argparse.FileType('r'))
    arguments = parser.parse_args()
    # argparse opened the file for us; make sure it gets closed
    with arguments.infile[0] as infile:
        input_dict = json.load(infile)

    output = input_dict["output"] if "output" in input_dict else os.getcwd()
    os.makedirs(output, exist_ok=True)
    # save out the json file used for input
    with open(os.path.join(output, "infile.json"), "w") as f:
        json.dump(input_dict, f, indent="")

    # get space_image_path
    try:
        space_image_path = input_dict["space_image_path"]
    except KeyError as err:
        # nothing below can work without it, so fail here instead of on an unbound name later
        raise KeyError("space_image_path is a required argument. It is a list of images, "
                       "with each image being a list of the format: "
                       "[\"space name\", \"image name\", \"image path\"]") from err
    # convert space_image_path to dictionary of dictionaries. (image_name:path key-values in a dict of space:img key-values)
    sip = {} # space-image-path dictionary
    for entry in space_image_path:
        sip.setdefault(entry[0], {})[entry[1]] = entry[2]

    # compute registrations
    if "registrations" in input_dict:
        try:
            configs = input_dict["configs"] # if registrations then there must also be configs
        except KeyError as err:
            raise KeyError("configs must be included in input with registrations. configs is a list of "
                           "full paths to JSON registration configuration files.") from err
        registrations = input_dict["registrations"]
        reg_list = [] # a list of dicts specifying inputs for each registration to perform
        for i, registration in enumerate(registrations):
            # registration format [[src_space, src_img], [target_space, target_img]]
            (src_space, src_img), (target_space, target_img) = registration[0][:2], registration[1][:2]
            reg_list.append({'registration': registration,
                             'source': sip[src_space][src_img],
                             'target': sip[target_space][target_img],
                             'config': configs[i],
                             'output': output})
        print('registration list: ', reg_list, '\n')
        print('running registrations...')
        graph = run_registrations(reg_list)
        # save graph. Note: if a graph was supplied as an argument, it will be merged with the new one before saving.
        if "graph" in input_dict:
            # NOTE(security): pickle.load executes arbitrary code; only load graph files you trust
            with open(input_dict["graph"], 'rb') as f:
                tmp_graph = pickle.load(f)
            graph  = graph.merge(tmp_graph)
        with open(os.path.join(output, 'graph.p'), 'wb') as f:
            pickle.dump(graph, f)
    elif "graph" in input_dict: # if we do not specify registrations, but we do specify a graph
        # NOTE(security): pickle.load executes arbitrary code; only load graph files you trust
        with open(input_dict["graph"], 'rb') as f:
            graph = pickle.load(f)
    else:
        raise ValueError("Input must contain either \"registrations\" (with \"configs\") or a \"graph\" "
                         "pickle file from previously computed registrations.")

    # daniel says, before we transform, we should update the sip
    # importantly one space does not have an image in it! (the registered space)
    for space in graph.spaces:
        sip.setdefault(space, {}) # these will be the registered spaces

    # "transform_all" may arrive as a JSON bool or as a string such as "true"/"True"
    transform_all = input_dict.get("transform_all", False)
    if isinstance(transform_all, str):
        transform_all = transform_all.lower() == 'true'
    else:
        transform_all = bool(transform_all)

    if transform_all:
        if DEBUG:
            print('Starting transform_all')

        for src_space in sip:
            if DEBUG:
                print(f'starting to transform from source {src_space}')
            # now what if there are no images in this space, we still want to output transforms
            images_to_iterate = sip[src_space]
            if not images_to_iterate:
                if DEBUG:
                    print('*'*80)
                    print('No images to iterate over, adding a None')
                images_to_iterate = [None]
            for src_image in images_to_iterate:
                if DEBUG:
                    print(f'starting to transform from source {src_space} image {src_image}')

                if src_image is not None:
                    src_path = sip[src_space][src_image]
                    I = emlddmm.Image(src_space, src_image, src_path, x=graph.spaces[src_space][1])
                else:
                    # we need to get a dummy image
                    if DEBUG:
                        print('No image here, getting a dummy image')
                    # Borrow the first image of the unregistered space. This is just to get
                    # file names if it is a series. (The path used here must be the one looked
                    # up for this space, not a path left over from a previous iteration.)
                    source_path = _reference_image_path(sip, src_space)
                    I = emlddmm.Image(src_space, src_image, source_path, x=graph.spaces[src_space][1])
                    I.data = None
                    I.title = 'slice_dataset'

                # reconstruct in every other space
                for target_space in [n for n in sip if n != src_space]:
                    print(f'starting to transform from source {src_space} image {src_image} to target space {target_space}')
                    if DEBUG and not sip[target_space]:
                        print('*'*80)
                        print(f'hi registered space {target_space}')
                    # first image of the target space (or of its unregistered counterpart when the
                    # target is a registered space). This is just to get file names if it is a series.
                    target_path = _reference_image_path(sip, target_space)
                    if DEBUG:
                        print('about to start graph_reconstruct')
                    if os.path.splitext(target_path)[-1] == '': # no file extension means a slice dataset
                        fnames = emlddmm.fnames(target_path)
                        graph_reconstruct(graph, output, I, target_space, target_fnames=fnames)
                    else:
                        graph_reconstruct(graph, output, I, target_space)

    # I still need to work on transforms to make it compatible with the above
    # also if the nissl space is mentioned
    elif "transforms" in input_dict:
        raise NotImplementedError('List of transformations not currently supported, only transform_all=True. TODO')

    if "registered_qc" in input_dict and input_dict['registered_qc']:
        # this will be a special flag
        # TODO: after this is all done, we should get a special qc in registered space
        # now that all the data is written out, we can just read it
        # I can implement it here and move it somewhere later
        # the idea is we find the registered space,
        # We load all data (atlas nissl fluoro) for each slice
        # we load labels and render them as curves
        print('In registered qc')
        # find the registered space from the graph

        print(input_dict)
        print('TODO: Registered QC is in progress')
# Run the command-line entry point only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()