# Source code for AFQ.definitions.mapping

import nibabel as nib
import numpy as np
import logging
from time import time
import os.path as op

from AFQ.definitions.utils import Definition, find_file
from dipy.align import syn_registration, affine_registration
import AFQ.registration as reg
import AFQ.data.s3bids as afs
from AFQ.tasks.utils import get_fname

from dipy.align.imaffine import AffineMap

try:
    from fsl.data.image import Image
    from fsl.transform.fnirt import readFnirt
    from fsl.transform.nonlinear import applyDeformation
    has_fslpy = True
except ModuleNotFoundError:
    has_fslpy = False

try:
    import h5py
    has_h5py = True
except ModuleNotFoundError:
    has_h5py = False

# Public mapping definitions. ItkMap and IdentityMap are user-facing
# (their docstrings show them passed to api.GroupAFQ), so export them too.
__all__ = [
    "FnirtMap", "SynMap", "SlrMap", "AffMap",
    "ItkMap", "IdentityMap"]


logger = logging.getLogger('AFQ.definitions.mapping')


# For map definitions, get_for_subses should return only the mapping,
# where the mapping has transform and transform_inverse methods
# that each accept data, **kwargs.


class FnirtMap(Definition):
    """
    Use an existing FNIRT map. Expects a warp file
    and an image file for each subject / session; image file
    is used as src space for warp.

    Parameters
    ----------
    warp_path : str, optional
        path to file to get warp from. Use this or warp_suffix.
        Default: None
    space_path : str, optional
        path to the image file defining the source space of the warp.
        Use this or space_suffix. Must be given together with warp_path.
        Default: None
    warp_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the warp file.
        Default: None
    space_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the space file.
        Default: None
    warp_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the warp file.
        Default: {}
    space_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the space file.
        Default: {}

    Raises
    ------
    ImportError
        If fslpy is not installed.
    ValueError
        If neither a path nor a suffix is given for the warp or the
        space file, or if only one of warp_path / space_path is given.

    Examples
    --------
    fnirt_map = FnirtMap(
        warp_suffix="warp",
        space_suffix="MNI",
        warp_filters={"scope": "TBSS"},
        space_filters={"scope": "TBSS"})
    api.GroupAFQ(mapping=fnirt_map)
    """

    def __init__(self, warp_path=None, space_path=None,
                 warp_suffix=None, space_suffix=None,
                 warp_filters=None, space_filters=None):
        if not has_fslpy:
            raise ImportError(
                "Please install fslpy if you want to use FnirtMap")
        if warp_path is None and warp_suffix is None:
            raise ValueError((
                "One of `warp_path` or `warp_suffix` should be set "
                "to a value other than None."))
        if space_path is None and space_suffix is None:
            raise ValueError(
                "One of space_path or space_suffix must not be None.")

        # Explicit paths only make sense as a pair: exactly one of them
        # being set is an error.
        if (warp_path is None) != (space_path is None):
            raise ValueError((
                "If passing a value for `warp_path`, "
                "you must also pass a value for `space_path`"))

        if warp_path is not None:
            self._from_path = True
            # Same (warp, space) pair is used for every subject/session.
            self.fnames = (warp_path, space_path)
        else:
            self._from_path = False
            self.warp_suffix = warp_suffix
            # None sentinel avoids one dict shared by all instances.
            self.warp_filters = {} if warp_filters is None else warp_filters
            self.space_suffix = space_suffix
            self.space_filters = \
                {} if space_filters is None else space_filters
            # Filled by find_path as fnames[session][subject].
            self.fnames = {}

    def find_path(self, bids_layout, from_path, subject, session):
        """
        Look up the warp and space files for one subject/session with
        find_file and cache them in ``self.fnames[session][subject]``.
        Does nothing when explicit paths were given.
        """
        if self._from_path:
            return
        if session not in self.fnames:
            self.fnames[session] = {}

        nearest_warp = find_file(
            bids_layout, from_path, self.warp_filters, self.warp_suffix,
            session, subject)

        nearest_space = find_file(
            bids_layout, from_path, self.space_filters, self.space_suffix,
            session, subject)

        self.fnames[session][subject] = (nearest_warp, nearest_space)

    def get_for_subses(self, base_fname, dwi, bids_info,
                       reg_subject, reg_template):
        """
        Build the mapping for one subject/session.

        The FNIRT warp is read with the space image as its source and
        ``dwi`` as its reference, and wrapped in a ConformedFnirtMapping
        that carries ``reg_template.affine``.
        """
        if self._from_path:
            nearest_warp, nearest_space = self.fnames
        else:
            nearest_warp, nearest_space = self.fnames[
                bids_info['session']][bids_info['subject']]

        our_templ = reg_template
        subj = Image(dwi)
        their_templ = Image(nearest_space)
        warp = readFnirt(nearest_warp, their_templ, subj)

        return ConformedFnirtMapping(warp, our_templ.affine)
class ConformedFnirtMapping():
    """
    ConformedFnirtMapping which matches the generic mapping API.

    Wraps an fslpy FNIRT deformation so it exposes the ``transform`` /
    ``transform_inverse`` methods used by the other pyAFQ mappings.
    Only ``transform_inverse`` is supported.

    Parameters
    ----------
    warp : deformation object
        Deformation as returned by ``fsl.transform.fnirt.readFnirt``.
    ref_affine : array-like, shape (4, 4)
        Affine attached to data passed to ``transform_inverse``.
    """

    def __init__(self, warp, ref_affine):
        self.ref_affine = ref_affine
        self.warp = warp

    def transform_inverse(self, data, **kwargs):
        """
        Apply the FNIRT deformation to ``data``.

        ``data`` is cast to float32 and wrapped in a NIfTI image with
        ``self.ref_affine`` before deformation. Extra keyword arguments
        are accepted for API compatibility and ignored.
        """
        data_img = Image(nib.Nifti1Image(
            data.astype(np.float32), self.ref_affine))
        return np.asarray(applyDeformation(data_img, self.warp).data)

    def transform(self, data, **kwargs):
        """Not supported for FNIRT mappings; always raises."""
        raise NotImplementedError(
            "Fnirt based mappings can currently"
            + " only transform from template to subject space")


class IdentityMap(Definition):
    """
    Does not perform any transformations from MNI to subject where
    pyAFQ normally would.

    Examples
    --------
    my_example_mapping = IdentityMap()
    api.GroupAFQ(mapping=my_example_mapping)
    """

    def __init__(self):
        pass

    def find_path(self, bids_layout, from_path, subject, session):
        # The identity mapping needs no input files.
        pass

    def get_for_subses(self, base_fname, dwi, bids_info,
                       reg_subject, reg_template):
        """
        Return an identity affine mapping between the subject grid
        (domain) and the template grid (codomain).
        """
        return ConformedAffineMapping(
            np.identity(4),
            domain_grid_shape=reg.reduce_shape(
                reg_subject.shape),
            domain_grid2world=reg_subject.affine,
            codomain_grid_shape=reg.reduce_shape(
                reg_template.shape),
            codomain_grid2world=reg_template.affine)


class ItkMap(Definition):
    """
    Use an existing Itk map (e.g., from ANTS). Expects the warp file
    from MNI to T1.

    Parameters
    ----------
    warp_path : str, optional
        path to file to get warp from. Use this or warp_suffix.
        Default: None
    warp_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the warp file.
        Default: None
    warp_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the warp file.
        Default: {}

    Raises
    ------
    ImportError
        If h5py is not installed.
    ValueError
        If neither warp_path nor warp_suffix is given.

    Examples
    --------
    itk_map = ItkMap(
        warp_suffix="xfm",
        warp_filters={
            "scope": "qsiprep",
            "from": "MNI152NLin2009cAsym",
            "to": "T1w"})
    api.GroupAFQ(mapping=itk_map)
    """

    def __init__(self, warp_path=None, warp_suffix=None, warp_filters=None):
        if not has_h5py:
            raise ImportError(
                "Please install h5py if you want to use ItkMap")
        if warp_path is None and warp_suffix is None:
            raise ValueError((
                "One of `warp_path` or `warp_suffix` should be set "
                "to a value other than None."))

        if warp_path is not None:
            self._from_path = True
            self.fname = warp_path
        else:
            self._from_path = False
            # Stored under the names that find_path reads.
            self.warp_suffix = warp_suffix
            self.warp_filters = {} if warp_filters is None else warp_filters
            # Older attribute names, kept for backward compatibility.
            self.suffix = self.warp_suffix
            self.filters = self.warp_filters
            # Filled by find_path as fnames[session][subject].
            self.fnames = {}

    def find_path(self, bids_layout, from_path, subject, session):
        """
        Look up the warp file (find_file with ``extension="h5"``) for one
        subject/session and cache it. No-op when warp_path was given.
        """
        if self._from_path:
            return
        if session not in self.fnames:
            self.fnames[session] = {}

        self.fnames[session][subject] = find_file(
            bids_layout, from_path, self.warp_filters, self.warp_suffix,
            session, subject, extension="h5")

    def get_for_subses(self, base_fname, dwi, bids_info,
                       reg_subject, reg_template):
        """
        Read the ITK HDF5 transform and convert it to a pyAFQ mapping.

        ``TransformGroup/1`` supplies the displacement field (its first
        three fixed parameters give the grid shape). ``TransformGroup/2``
        supplies a 3x3 matrix plus translation, used as the prealign.

        Raises
        ------
        ValueError
            If the displacement grid shape differs from the shape of
            ``reg_template``.
        """
        if self._from_path:
            nearest_warp = self.fname
        else:
            nearest_warp = self.fnames[
                bids_info['session']][bids_info['subject']]

        # Read-only and closed even on error: the warp is an input file.
        with h5py.File(nearest_warp, 'r') as warp_f5:
            transforms = warp_f5["TransformGroup"]
            their_shape = np.asarray(
                transforms['1']['TransformFixedParameters'],
                dtype=int)[:3]
            # Header shape only; avoids loading the template volume.
            our_shape = tuple(reg_template.shape)
            if tuple(their_shape) != our_shape:
                raise ValueError((
                    f"The shape of your ITK mapping ({their_shape})"
                    f" is not the same as your template for registration"
                    f" ({our_shape})"))
            their_forward = np.asarray(
                transforms['1']['TransformParameters']).reshape(
                    [*their_shape, 3])
            affine_params = np.asarray(
                transforms["2"]["TransformParameters"])

        # Forward displacement goes in slot 0 of the last axis; slot 1 is
        # left as zeros. NOTE(review): presumably slot 1 is the backward
        # field expected by reg.read_mapping -- confirm.
        their_disp = np.zeros((*their_shape, 3, 2))
        their_disp[..., 0] = their_forward
        their_disp = nib.Nifti1Image(
            their_disp, reg_template.affine)

        their_prealign = np.eye(4)
        their_prealign[:3, :3] = affine_params[:9].reshape((3, 3))
        their_prealign[:3, 3] = affine_params[9:]

        return reg.read_mapping(
            their_disp, dwi,
            reg_template, prealign=their_prealign)


class GeneratedMapMixin(object):
    """
    Helper Class
    Useful for maps that are generated by pyAFQ

    Subclasses provide ``extension``, ``use_prealign`` and
    ``gen_mapping``, plus ``affine_kwargs`` when ``prealign`` is used.
    """

    def get_fnames(self, extension, base_fname):
        """Return (mapping file path, metadata JSON path)."""
        mapping_file = get_fname(
            base_fname,
            '_mapping_from-DWI_to_MNI_xfm')
        meta_fname = get_fname(base_fname, '_mapping_reg')
        mapping_file = mapping_file + extension
        meta_fname = f'{meta_fname}.json'
        return mapping_file, meta_fname

    def prealign(self, base_fname, reg_subject, reg_template, save=True):
        """
        Compute (or reuse) the subject-to-template affine pre-alignment.

        If the prealign ``.npy`` file already exists it is reused.
        Otherwise affine_registration is run. With ``save=True`` the
        result and a timing JSON are written, and the file path is
        returned. With ``save=False`` the affine array is returned and
        nothing new is written.
        """
        prealign_file = get_fname(
            base_fname, '_prealign_from-DWI_to-MNI_xfm.npy')
        if not op.exists(prealign_file):
            start_time = time()
            _, aff = affine_registration(
                reg_subject, reg_template,
                **self.affine_kwargs)
            meta = dict(
                type="rigid",
                timing=time() - start_time)
            if not save:
                return aff
            logger.info(f"Saving {prealign_file}")
            np.save(prealign_file, aff)
            meta_fname = get_fname(
                base_fname, '_prealign_from-DWI_to-MNI_xfm.json')
            afs.write_json(meta_fname, meta)
        return prealign_file if save else np.load(prealign_file)

    def get_for_subses(self, base_fname, dwi, bids_info,
                       reg_subject, reg_template,
                       subject_sls=None, template_sls=None):
        """
        Generate the mapping if it is not already on disk, then read it.

        A newly generated mapping is written with reg.write_mapping along
        with a timing JSON. The saved mapping is always re-read with
        reg.read_mapping, using the inverse prealign when
        ``use_prealign`` is set.
        """
        mapping_file, meta_fname = self.get_fnames(
            self.extension, base_fname)

        if self.use_prealign:
            reg_prealign = np.load(self.prealign(
                base_fname, reg_subject, reg_template))
        else:
            reg_prealign = None

        if not op.exists(mapping_file):
            start_time = time()
            mapping = self.gen_mapping(
                base_fname, reg_subject, reg_template,
                subject_sls, template_sls,
                reg_prealign)
            total_time = time() - start_time

            logger.info(f"Saving {mapping_file}")
            reg.write_mapping(mapping, mapping_file)
            meta = dict(
                type="displacementfield",
                timing=total_time)
            afs.write_json(meta_fname, meta)

        reg_prealign_inv = np.linalg.inv(reg_prealign) if self.use_prealign\
            else None
        mapping = reg.read_mapping(
            mapping_file,
            dwi,
            reg_template,
            prealign=reg_prealign_inv)
        return mapping
class SynMap(GeneratedMapMixin, Definition):
    """
    Calculate a Syn registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    use_prealign : bool
        Whether to perform a linear pre-registration.
        Default: True
    affine_kwargs : dictionary, optional
        Parameters to pass to affine_registration
        in dipy.align, which does the linear pre-alignment.
        Only used if use_prealign is True.
        Default: {}
    syn_kwargs : dictionary, optional
        Parameters to pass to syn_registration
        in dipy.align, which does the SyN alignment.
        Default: {}

    Examples
    --------
    api.GroupAFQ(mapping=SynMap())
    """

    def __init__(self, use_prealign=True, affine_kwargs=None,
                 syn_kwargs=None):
        self.use_prealign = use_prealign
        # None sentinels avoid dicts shared between all instances.
        self.affine_kwargs = {} if affine_kwargs is None else affine_kwargs
        self.syn_kwargs = {} if syn_kwargs is None else syn_kwargs
        self.extension = ".nii.gz"

    def find_path(self, bids_layout, from_path, subject, session):
        # The mapping is generated by pyAFQ; no input files to locate.
        pass

    def gen_mapping(self, base_fname, reg_subject, reg_template,
                    subject_sls, template_sls, reg_prealign):
        """
        Run dipy's syn_registration of reg_subject (moving) onto
        reg_template (static), starting from ``reg_prealign``.
        Streamline arguments are unused.
        """
        _, mapping = syn_registration(
            reg_subject.get_fdata(),
            reg_template.get_fdata(),
            moving_affine=reg_subject.affine,
            static_affine=reg_template.affine,
            prealign=reg_prealign,
            **self.syn_kwargs)
        if self.use_prealign:
            # NOTE(review): replaces codomain_world2grid with the inverse
            # prealign; presumably coordinated with reg.write_mapping /
            # read_mapping -- confirm before changing.
            mapping.codomain_world2grid = np.linalg.inv(reg_prealign)
        return mapping
class SlrMap(GeneratedMapMixin, Definition):
    """
    Calculate a SLR registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    slr_kwargs : dictionary, optional
        Parameters to pass to whole_brain_slr
        in dipy, which does the SLR alignment.
        Default: {}

    Examples
    --------
    api.GroupAFQ(mapping=SlrMap())
    """

    def __init__(self, slr_kwargs=None):
        # Honor the caller's kwargs (they were previously discarded);
        # None sentinel avoids a dict shared between instances.
        self.slr_kwargs = {} if slr_kwargs is None else slr_kwargs
        self.use_prealign = False
        self.extension = ".npy"

    def find_path(self, bids_layout, from_path, subject, session):
        # The mapping is generated by pyAFQ; no input files to locate.
        pass

    def gen_mapping(self, base_fname, reg_subject, reg_template,
                    subject_sls, template_sls, reg_prealign):
        """
        Register subject streamlines (moving) to template streamlines
        (static) with reg.slr_registration. ``reg_prealign`` is unused.

        The argument order matches how GeneratedMapMixin.get_for_subses
        calls gen_mapping (subject before template).
        """
        return reg.slr_registration(
            subject_sls, template_sls,
            moving_affine=reg_subject.affine,
            moving_shape=reg_subject.shape,
            static_affine=reg_template.affine,
            static_shape=reg_template.shape,
            **self.slr_kwargs)
class AffMap(GeneratedMapMixin, Definition):
    """
    Calculate an affine registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    affine_kwargs : dictionary, optional
        Parameters to pass to affine_registration
        in dipy.align, which does the linear pre-alignment.
        Default: {}

    Examples
    --------
    api.GroupAFQ(mapping=AffMap())
    """

    def __init__(self, affine_kwargs=None):
        self.use_prealign = False
        # None sentinel avoids a dict shared between all instances.
        self.affine_kwargs = {} if affine_kwargs is None else affine_kwargs
        self.extension = ".npy"

    def find_path(self, bids_layout, from_path, subject, session):
        # The mapping is generated by pyAFQ; no input files to locate.
        pass

    def gen_mapping(self, base_fname, reg_subject, reg_template,
                    subject_sls, template_sls, reg_prealign):
        """
        Build a ConformedAffineMapping from the inverse of the affine
        computed by ``prealign(..., save=False)``. The subject grid is
        the domain and the template grid is the codomain. Streamline
        and ``reg_prealign`` arguments are unused.
        """
        return ConformedAffineMapping(
            np.linalg.inv(self.prealign(
                base_fname,
                reg_subject, reg_template, save=False)),
            domain_grid_shape=reg.reduce_shape(
                reg_subject.shape),
            domain_grid2world=reg_subject.affine,
            codomain_grid_shape=reg.reduce_shape(
                reg_template.shape),
            codomain_grid2world=reg_template.affine)
class ConformedAffineMapping(AffineMap):
    """
    AffineMap adapted to the DiffeomorphicMap calling convention, so
    that SLR/affine maps can be used interchangeably with SyN maps.

    ``transform`` delegates to ``AffineMap.transform_inverse`` and
    ``transform_inverse`` delegates to ``AffineMap.transform``. In both
    cases the ``interpolation`` keyword is forwarded as ``interp``.
    NOTE(review): ``interp`` matches older dipy releases of AffineMap;
    confirm against the pinned dipy version.
    """

    def transform(self, *args, interpolation='linear', **kwargs):
        forwarded = dict(kwargs, interp=interpolation)
        return super().transform_inverse(*args, **forwarded)

    def transform_inverse(self, *args, interpolation='linear', **kwargs):
        forwarded = dict(kwargs, interp=interpolation)
        return super().transform(*args, **forwarded)